Skip to content

Commit

Permalink
[src] Feature bank feature extraction using CUDA (#3544)
Browse files Browse the repository at this point in the history
Following this change, both MFCC and fbank run
through a single code path with parameters
(use_power, use_log_fbank and use_dct) controlling
the flow.

CudaMfcc has been renamed to CudaSpectralFeatures.

It contains an MfccOptions structure which contains
FrameOptions and MelOptions. It can be initialized
either with an MfccOptions object or an FbankOptions
object.

Compared with the MfccOptions previously taken by CudaMfcc,
CudaSpectralFeatureOptions also contains these parameters:

use_dct  - switches on the discrete cosine and lifter
use_log_fbank - takes the log of the mel bank values
use_power - uses power in place of abs(amplitude)

Each of these defaults to true for MFCC. For fbank,
use_dct is set to false, and the other two are taken from
the user-supplied FbankOptions.

Also added a unit test for CUDA Fbank
(cudafeatbin/compute-fbank-feats-cuda).
  • Loading branch information
LeviBarnes authored and danpovey committed Aug 28, 2019
1 parent 61bc12e commit 9a83681
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 53 deletions.
2 changes: 1 addition & 1 deletion src/cudafeat/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ifeq ($(CUDA), true)
TESTFILES =

ifeq ($(CUDA), true)
OBJFILES += feature-window-cuda.o feature-mfcc-cuda.o feature-online-cmvn-cuda.o \
OBJFILES += feature-window-cuda.o feature-spectral-cuda.o feature-online-cmvn-cuda.o \
online-ivector-feature-cuda-kernels.o online-ivector-feature-cuda.o \
online-cuda-feature-pipeline.o
endif
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// cudafeature/feature-mfcc-cuda.cu
// cudafeature/feature-spectral-cuda.cu
//
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Justin Luitjens
Expand All @@ -20,7 +20,7 @@
#include <cub/cub.cuh>
#endif

#include "cudafeat/feature-mfcc-cuda.h"
#include "cudafeat/feature-spectral-cuda.h"
#include "cudamatrix/cu-rand.h"

// Each thread block processes a unique frame
Expand Down Expand Up @@ -60,7 +60,8 @@ __global__ void apply_lifter_and_floor_energy(
// Threads in the same block compute the row collaboratively.
// This kernel must be called out of place (A_in!=A_out).
__global__ void power_spectrum_kernel(int row_length, float *A_in, int32_t ldi,
float *A_out, int32_t ldo) {
float *A_out, int32_t ldo,
bool use_power) {
int thread_id = threadIdx.x;
int block_id = blockIdx.x;
float *Ar = A_in + block_id * ldi;
Expand All @@ -73,7 +74,11 @@ __global__ void power_spectrum_kernel(int row_length, float *A_in, int32_t ldi,

float2 val = reinterpret_cast<float2 *>(Ar)[idx];
float ret = val.x * val.x + val.y * val.y;
Aw[idx] = ret;
if (use_power) {
Aw[idx] = ret;
} else {
Aw[idx] = sqrtf(ret);
}
}

// handle special case
Expand All @@ -84,17 +89,23 @@ __global__ void power_spectrum_kernel(int row_length, float *A_in, int32_t ldi,
// internal implementation
float im = Ar[row_length];

Aw[0] = real * real;
Aw[half_length] = im * im;
if (use_power) {
Aw[0] = real * real;
Aw[half_length] = im * im;
} else {
Aw[0] = fabs(real);
Aw[half_length] = fabs(im);
}
}
}

// Expects to be called with 32x8 sized thread block.
// LDB: Adding use_log flag
__global__ void mel_banks_compute_kernel(int32_t num_frames, float energy_floor,
int32 *offsets, int32 *sizes,
float **vecs, const float *feats,
int32_t ldf, float *mels,
int32_t ldm) {
int32_t ldf, float *mels, int32_t ldm,
bool use_log) {
// Specialize WarpReduce for type float
typedef cub::WarpReduce<float> WarpReduce;
// Allocate WarpReduce shared memory for 8 warps
Expand Down Expand Up @@ -125,10 +136,14 @@ __global__ void mel_banks_compute_kernel(int32_t num_frames, float energy_floor,
// Sum in cub
sum = WarpReduce(temp_storage[wid]).Sum(sum);
if (tid == 0) {
// avoid log of zero
if (sum < energy_floor) sum = energy_floor;
float val = logf(sum);
mels[frame * ldm + bin] = val;
if (use_log) {
// avoid log of zero
if (sum < energy_floor) sum = energy_floor;
float val = logf(sum);
mels[frame * ldm + bin] = val;
} else {
mels[frame * ldm + bin] = sum;
}
}
}

Expand Down Expand Up @@ -341,11 +356,11 @@ __global__ void dot_log_kernel(int32_t num_frames, int32_t frame_length,

namespace kaldi {

CudaMfcc::CudaMfcc(const MfccOptions &opts)
: MfccComputer(opts),
CudaSpectralFeatures::CudaSpectralFeatures(const CudaSpectralFeatureOptions &opts)
: MfccComputer(opts.mfcc_opts),
cu_lifter_coeffs_(lifter_coeffs_),
cu_dct_matrix_(dct_matrix_),
window_function_(opts.frame_opts) {
window_function_(opts.mfcc_opts.frame_opts) {
const MelBanks *mel_banks = GetMelBanks(1.0);
const std::vector<std::pair<int32, Vector<BaseFloat>>> &bins =
mel_banks->GetBins();
Expand Down Expand Up @@ -376,8 +391,8 @@ CudaMfcc::CudaMfcc(const MfccOptions &opts)
cudaMemcpyHostToDevice, cudaStreamPerThread));
CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread));

frame_length_ = opts.frame_opts.WindowSize();
padded_length_ = opts.frame_opts.PaddedWindowSize();
frame_length_ = opts.mfcc_opts.frame_opts.WindowSize();
padded_length_ = opts.mfcc_opts.frame_opts.PaddedWindowSize();
fft_length_ = padded_length_ / 2; // + 1;
fft_size_ = 800;

Expand All @@ -395,12 +410,13 @@ CudaMfcc::CudaMfcc(const MfccOptions &opts)
cufftPlanMany(&plan_, 1, &padded_length_, NULL, 1, stride_, NULL, 1,
tmp_stride_ / 2, CUFFT_R2C, fft_size_);
cufftSetStream(plan_, cudaStreamPerThread);
cumfcc_opts_ = opts;
}

// ExtractWindow extracts a windowed frame of waveform with a power-of-two,
// padded size. It does mean subtraction, pre-emphasis and dithering as
// requested.
void CudaMfcc::ExtractWindows(int32_t num_frames, int64 sample_offset,
void CudaSpectralFeatures::ExtractWindows(int32_t num_frames, int64 sample_offset,
const CuVectorBase<BaseFloat> &wave,
const FrameExtractionOptions &opts) {
KALDI_ASSERT(sample_offset >= 0 && wave.Dim() != 0);
Expand All @@ -415,7 +431,7 @@ void CudaMfcc::ExtractWindows(int32_t num_frames, int64 sample_offset,
CU_SAFE_CALL(cudaGetLastError());
}

void CudaMfcc::ProcessWindows(int num_frames,
void CudaSpectralFeatures::ProcessWindows(int num_frames,
const FrameExtractionOptions &opts,
CuVectorBase<BaseFloat> *log_energy_pre_window) {
if (num_frames == 0) return;
Expand All @@ -433,15 +449,16 @@ void CudaMfcc::ProcessWindows(int num_frames,
CU_SAFE_CALL(cudaGetLastError());
}

void CudaMfcc::ComputeFinalFeatures(int num_frames, BaseFloat vtln_wrap,
void CudaSpectralFeatures::ComputeFinalFeatures(int num_frames, BaseFloat vtln_wrap,
CuVector<BaseFloat> *cu_signal_log_energy,
CuMatrix<BaseFloat> *cu_features) {
MfccOptions mfcc_opts = cumfcc_opts_.mfcc_opts;
Vector<float> tmp;
assert(opts_.htk_compat == false);
assert(mfcc_opts.htk_compat == false);

if (num_frames == 0) return;

if (opts_.use_energy && !opts_.raw_energy) {
if (mfcc_opts.use_energy && !mfcc_opts.raw_energy) {
dot_log_kernel<<<num_frames, CU1DBLOCK>>>(
num_frames, cu_windows_.NumCols(), cu_windows_.Data(),
cu_windows_.Stride(), cu_signal_log_energy->Data());
Expand All @@ -465,7 +482,7 @@ void CudaMfcc::ComputeFinalFeatures(int num_frames, BaseFloat vtln_wrap,

power_spectrum_kernel<<<num_frames, CU1DBLOCK>>>(
padded_length_, tmp_window_.Data(), tmp_window_.Stride(),
power_spectrum.Data(), power_spectrum.Stride());
power_spectrum.Data(), power_spectrum.Stride(), cumfcc_opts_.use_power);
CU_SAFE_CALL(cudaGetLastError());

// mel banks
Expand All @@ -476,24 +493,32 @@ void CudaMfcc::ComputeFinalFeatures(int num_frames, BaseFloat vtln_wrap,
mel_banks_compute_kernel<<<mel_blocks, mel_threads>>>(
num_frames, std::numeric_limits<float>::epsilon(), offsets_, sizes_,
vecs_, power_spectrum.Data(), power_spectrum.Stride(),
cu_mel_energies_.Data(), cu_mel_energies_.Stride());
cu_mel_energies_.Data(), cu_mel_energies_.Stride(),
cumfcc_opts_.use_log_fbank);
CU_SAFE_CALL(cudaGetLastError());

// dct transform
cu_features->AddMatMat(1.0, cu_mel_energies_, kNoTrans, cu_dct_matrix_,
kTrans, 0.0);

apply_lifter_and_floor_energy<<<num_frames, CU1DBLOCK>>>(
cu_features->NumRows(), cu_features->NumCols(), opts_.cepstral_lifter,
opts_.use_energy, opts_.energy_floor, cu_signal_log_energy->Data(),
cu_lifter_coeffs_.Data(), cu_features->Data(), cu_features->Stride());
if (cumfcc_opts_.use_dct) {
cu_features->AddMatMat(1.0, cu_mel_energies_, kNoTrans, cu_dct_matrix_,
kTrans, 0.0);

apply_lifter_and_floor_energy<<<num_frames, CU1DBLOCK>>>(
cu_features->NumRows(), cu_features->NumCols(),
mfcc_opts.cepstral_lifter, mfcc_opts.use_energy,
mfcc_opts.energy_floor, cu_signal_log_energy->Data(),
cu_lifter_coeffs_.Data(), cu_features->Data(), cu_features->Stride());
} else {
cudaMemcpyAsync(cu_features->Data(), cu_mel_energies_.Data(),
sizeof(BaseFloat) * num_frames * cu_features->Stride(),
cudaMemcpyDeviceToDevice, cudaStreamPerThread);
}
CU_SAFE_CALL(cudaGetLastError());
}

void CudaMfcc::ComputeFeatures(const CuVectorBase<BaseFloat> &cu_wave,
void CudaSpectralFeatures::ComputeFeatures(const CuVectorBase<BaseFloat> &cu_wave,
BaseFloat sample_freq, BaseFloat vtln_warp,
CuMatrix<BaseFloat> *cu_features) {
nvtxRangePushA("CudaMfcc::ComputeFeatures");
nvtxRangePushA("CudaSpectralFeatures::ComputeFeatures");
const FrameExtractionOptions &frame_opts = GetFrameOptions();
int num_frames = NumFrames(cu_wave.Dim(), frame_opts, true);
// compute fft frames by rounding up to a multiple of fft_size_
Expand Down Expand Up @@ -533,7 +558,7 @@ void CudaMfcc::ComputeFeatures(const CuVectorBase<BaseFloat> &cu_wave,

nvtxRangePop();
}
CudaMfcc::~CudaMfcc() {
CudaSpectralFeatures::~CudaSpectralFeatures() {
delete[] cu_vecs_;
CuDevice::Instantiate().Free(vecs_);
CuDevice::Instantiate().Free(offsets_);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// cudafeat/feature-mfcc-cuda.h
// cudafeat/feature-spectral-cuda.h
//
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Justin Luitjens
Expand All @@ -25,20 +25,65 @@
#include "cudafeat/feature-window-cuda.h"
#include "cudamatrix/cu-matrix.h"
#include "cudamatrix/cu-vector.h"
#include "feat/feature-fbank.h"
#include "feat/feature-mfcc.h"

namespace kaldi {
// This class implements MFCC computation in CUDA.
enum SpectralFeatureType {MFCC, FBANK};
// Options for CudaSpectralFeatures. Bundles an MfccOptions object with the
// extra switches that let MFCC and fbank share a single GPU code path.
struct CudaSpectralFeatureOptions {
  MfccOptions mfcc_opts;
  // Take the log of the mel bank values.
  bool use_log_fbank;
  // Use the power spectrum rather than the magnitude spectrum.
  bool use_power;
  // Apply the DCT and lifter; switched off for fbank.
  bool use_dct;
  SpectralFeatureType feature_type;

  // Default configuration is MFCC with every switch enabled.
  CudaSpectralFeatureOptions()
      : use_log_fbank(true),
        use_power(true),
        use_dct(true),
        feature_type(MFCC) {}

  // MFCC configuration: copies the supplied options and enables every switch.
  // Intentionally not explicit, so an MfccOptions converts implicitly.
  CudaSpectralFeatureOptions(MfccOptions opts_in)
      : mfcc_opts(opts_in),
        use_log_fbank(true),
        use_power(true),
        use_dct(true),
        feature_type(MFCC) {}

  // Fbank configuration: copies the fields shared with MfccOptions, takes
  // use_log_fbank and use_power from the fbank options, and disables the DCT.
  // Intentionally not explicit, so a FbankOptions converts implicitly.
  CudaSpectralFeatureOptions(FbankOptions opts)
      : use_log_fbank(opts.use_log_fbank),
        use_power(opts.use_power),
        use_dct(false),
        feature_type(FBANK) {
    mfcc_opts.frame_opts = opts.frame_opts;
    mfcc_opts.mel_opts = opts.mel_opts;
    mfcc_opts.use_energy = opts.use_energy;
    mfcc_opts.energy_floor = opts.energy_floor;
    mfcc_opts.raw_energy = opts.raw_energy;
    mfcc_opts.htk_compat = opts.htk_compat;
    mfcc_opts.cepstral_lifter = 0.0f;
  }
};
// This class implements MFCC and Fbank computation in CUDA.
// It takes input from device memory and outputs to
// device memory. It also does no synchronization.
class CudaMfcc : public MfccComputer {
class CudaSpectralFeatures : public MfccComputer {
public:
void ComputeFeatures(const CuVectorBase<BaseFloat> &cu_wave,
BaseFloat sample_freq, BaseFloat vtln_warp,
CuMatrix<BaseFloat> *cu_features);

CudaMfcc(const MfccOptions &opts);
~CudaMfcc();
CudaSpectralFeatures(const CudaSpectralFeatureOptions &opts);
~CudaSpectralFeatures();
CudaSpectralFeatureOptions cumfcc_opts_;
// Returns the dimension of the output feature vectors, which differs
// between MFCC and fbank. For MFCC this defers to MfccComputer::Dim();
// for fbank it is the number of mel bins, plus one when use_energy is set.
// Declared const so it can be called on a const object (this declaration
// hides the inherited Dim()).
int32 Dim() const {
  if (cumfcc_opts_.feature_type == MFCC) return MfccComputer::Dim();
  // Fbank: the dimension comes from the mel configuration, not num_ceps.
  const MfccOptions &mfcc_opts = cumfcc_opts_.mfcc_opts;
  return mfcc_opts.mel_opts.num_bins + (mfcc_opts.use_energy ? 1 : 0);
}

private:
void ExtractWindows(int32 num_frames, int64 sample_offset,
Expand Down
16 changes: 9 additions & 7 deletions src/cudafeat/online-cuda-feature-pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ namespace kaldi {

OnlineCudaFeaturePipeline::OnlineCudaFeaturePipeline(
const OnlineNnet2FeaturePipelineConfig &config)
: info_(config), mfcc(NULL), ivector(NULL) {
: info_(config), spectral_feat(NULL), ivector(NULL) {
if (info_.feature_type == "mfcc") {
mfcc = new CudaMfcc(info_.mfcc_opts);
spectral_feat = new CudaSpectralFeatures(info_.mfcc_opts);
}
if (info_.feature_type == "fbank") {
spectral_feat = new CudaSpectralFeatures(info_.fbank_opts);
}

if (info_.use_ivectors) {
Expand All @@ -43,27 +46,26 @@ OnlineCudaFeaturePipeline::OnlineCudaFeaturePipeline(
}

// Releases the feature extractor and the ivector extractor owned by this
// pipeline. Either pointer may be NULL if the corresponding component was
// never created; delete on a null pointer is a no-op, so no guard is needed.
// (The previous guard used '=' instead of '!=', which nulled spectral_feat
// and leaked the object.)
OnlineCudaFeaturePipeline::~OnlineCudaFeaturePipeline() {
  delete spectral_feat;
  delete ivector;
}

// Computes the base features for cu_wave into input_features and, when
// ivectors are enabled and an output vector is supplied, the ivector into
// ivector_features.
//
// Both "mfcc" and "fbank" feature types go through the same
// CudaSpectralFeatures object; any other feature type trips an assertion.
void OnlineCudaFeaturePipeline::ComputeFeatures(
    const CuVectorBase<BaseFloat> &cu_wave, BaseFloat sample_freq,
    CuMatrix<BaseFloat> *input_features,
    CuVector<BaseFloat> *ivector_features) {
  if (info_.feature_type == "mfcc" || info_.feature_type == "fbank") {
    // Fbank is computed via the MFCC code path.
    float vtln_warp = 1.0f;
    spectral_feat->ComputeFeatures(cu_wave, sample_freq, vtln_warp,
                                   input_features);
  } else {
    KALDI_ASSERT(false);
  }

  // Ivector. Running without ivectors is a valid configuration, so the
  // absence of ivector output is not treated as an error here.
  if (info_.use_ivectors && ivector_features != NULL) {
    ivector->GetIvector(*input_features, ivector_features);
  }
}

Expand Down
4 changes: 2 additions & 2 deletions src/cudafeat/online-cuda-feature-pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include <vector>

#include "base/kaldi-error.h"
#include "cudafeat/feature-mfcc-cuda.h"
#include "cudafeat/feature-spectral-cuda.h"
#include "cudafeat/online-ivector-feature-cuda.h"
#include "matrix/matrix-lib.h"
#include "online2/online-nnet2-feature-pipeline.h"
Expand All @@ -47,7 +47,7 @@ class OnlineCudaFeaturePipeline {

private:
OnlineNnet2FeaturePipelineInfo info_;
CudaMfcc *mfcc;
CudaSpectralFeatures *spectral_feat;
IvectorExtractorFastCuda *ivector;
};
} // namespace kaldi
Expand Down
2 changes: 1 addition & 1 deletion src/cudafeatbin/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ LDLIBS += $(CUDA_LDLIBS)
BINFILES =

ifeq ($(CUDA), true)
BINFILES += compute-mfcc-feats-cuda apply-cmvn-online-cuda compute-online-feats-cuda
BINFILES += compute-mfcc-feats-cuda apply-cmvn-online-cuda compute-online-feats-cuda compute-fbank-feats-cuda
endif

OBJFILES =
Expand Down
Loading

0 comments on commit 9a83681

Please sign in to comment.