Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing the propagation of allow_{upsample,downsample} #3139

Merged
merged 2 commits into from
Mar 18, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 74 additions & 24 deletions src/feat/online-feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

namespace kaldi {

// Constructs an empty RecyclingVector.
// An 'items_to_hold' of 0 is stored as -1 in items_to_hold_
// (NOTE(review): presumably -1 means "keep everything"; confirm against the
// class declaration).  first_available_index_ starts at 0.
RecyclingVector::RecyclingVector(int items_to_hold):
    items_to_hold_(items_to_hold == 0 ? -1 : items_to_hold),
    first_available_index_(0) {
}
Expand All @@ -38,7 +38,8 @@ RecyclingVector::~RecyclingVector() {
Vector<BaseFloat> *RecyclingVector::At(int index) const {
if (index < first_available_index_) {
KALDI_ERR << "Attempted to retrieve feature vector that was "
"already removed by the RecyclingVector (index = " << index << "; "
"already removed by the RecyclingVector (index = "
<< index << "; "
<< "first_available_index = " << first_available_index_ << "; "
<< "size = " << Size() << ")";
}
Expand All @@ -59,43 +60,93 @@ int RecyclingVector::Size() const {
return first_available_index_ + items_.size();
}


template<class C>
template <class C>
void OnlineGenericBaseFeature<C>::GetFrame(int32 frame,
VectorBase<BaseFloat> *feat) {
feat->CopyFromVec(*(features_.At(frame)));
};

// Builds the feature computer from 'opts', derives the window function from
// the computer's frame options, and sizes the feature store from
// opts.frame_opts.max_feature_vectors.  No input has been seen yet, so
// input_finished_ is false and waveform_offset_ is 0.
template <class C>
OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
    const typename C::Options &opts):
    computer_(opts), window_function_(computer_.GetFrameOptions()),
    features_(opts.frame_opts.max_feature_vectors),
    input_finished_(false), waveform_offset_(0) { }

// Creates resampler_ on demand when 'sampling_rate' (the rate of the incoming
// waveform) differs from the rate the feature computer expects.
//  - If a resampler already exists, only asserts that its input and output
//    rates still match; the rate must not change between calls.
//  - If the rates are equal, nothing is done.
//  - If the input rate is higher than expected and allow_downsample is set,
//    or lower than expected and allow_upsample is set, a LinearResample is
//    created to convert to the expected rate.
//  - Any other mismatch is reported through KALDI_ERR.
template <class C>
void OnlineGenericBaseFeature<C>::MaybeCreateResampler(
    BaseFloat sampling_rate) {
  const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
  const BaseFloat expected_sampling_rate = frame_opts.samp_freq;

  if (resampler_ != nullptr) {
    KALDI_ASSERT(resampler_->GetInputSamplingRate() == sampling_rate);
    KALDI_ASSERT(resampler_->GetOutputSamplingRate() == expected_sampling_rate);
    return;
  }
  if (sampling_rate == expected_sampling_rate)
    return;  // Rates already agree; no resampler needed.

  // Downsampling applies to input that is faster than expected, upsampling to
  // input that is slower; each must be explicitly enabled.
  const bool can_downsample = sampling_rate > expected_sampling_rate &&
                              frame_opts.allow_downsample;
  const bool can_upsample = sampling_rate < expected_sampling_rate &&
                            frame_opts.allow_upsample;

  if (can_downsample || can_upsample) {
    // Cut off at half of the lower of the two rates.
    resampler_.reset(new LinearResample(
        sampling_rate, expected_sampling_rate,
        std::min(sampling_rate / 2, expected_sampling_rate / 2), 6));
  } else {
    KALDI_ERR << "Sampling frequency mismatch, expected "
              << expected_sampling_rate << ", got " << sampling_rate
              << "\nPerhaps you want to use the options "
                 "--allow_{upsample,downsample}";
  }
}

// Signals that no more waveform will arrive.  If a resampler is in use it is
// flushed first, and any samples it still held are appended to
// waveform_remainder_.  Then input_finished_ is set and ComputeFeatures() is
// called to process whatever is left.
template <class C>
void OnlineGenericBaseFeature<C>::InputFinished() {
  if (resampler_ != nullptr) {
    // Call Resample() with an empty input and the second argument set to
    // true (NOTE(review): taken to be the flush flag; confirm in resample.h)
    // so that the resampler emits any samples it is still holding.
    Vector<BaseFloat> empty_input;
    Vector<BaseFloat> resampled_wave;
    resampler_->Resample(empty_input, true, &resampled_wave);

    if (resampled_wave.Dim() != 0) {
      // The destination must be sized before Range() is taken on it.
      Vector<BaseFloat> appended_wave(waveform_remainder_.Dim() +
                                      resampled_wave.Dim());
      if (waveform_remainder_.Dim() != 0)
        appended_wave.Range(0, waveform_remainder_.Dim())
            .CopyFromVec(waveform_remainder_);
      appended_wave.Range(waveform_remainder_.Dim(), resampled_wave.Dim())
          .CopyFromVec(resampled_wave);
      waveform_remainder_.Swap(&appended_wave);
    }
  }
  input_finished_ = true;
  ComputeFeatures();
}

// Accepts a new chunk of waveform sampled at 'sampling_rate'.
// An empty chunk is ignored.  Calling this after InputFinished() is an error.
// The chunk is resampled when MaybeCreateResampler() has set up a resampler,
// then appended to waveform_remainder_, and ComputeFeatures() is called.
template <class C>
void OnlineGenericBaseFeature<C>::AcceptWaveform(
    BaseFloat sampling_rate, const VectorBase<BaseFloat> &original_waveform) {
  if (original_waveform.Dim() == 0)
    return;  // Nothing to do.
  if (input_finished_)
    KALDI_ERR << "AcceptWaveform called after InputFinished() was called.";

  MaybeCreateResampler(sampling_rate);

  // 'waveform' points at the samples to append: the caller's data when no
  // resampler is in use, otherwise the resampled copy.
  Vector<BaseFloat> resampled_wave;
  const VectorBase<BaseFloat> *waveform = &original_waveform;
  if (resampler_ != nullptr) {
    resampler_->Resample(original_waveform, false, &resampled_wave);
    waveform = &resampled_wave;
  }

  // append 'waveform' to 'waveform_remainder_.'
  Vector<BaseFloat> appended_wave(waveform_remainder_.Dim() + waveform->Dim());
  if (waveform_remainder_.Dim() != 0)
    appended_wave.Range(0, waveform_remainder_.Dim())
        .CopyFromVec(waveform_remainder_);
  appended_wave.Range(waveform_remainder_.Dim(), waveform->Dim())
      .CopyFromVec(*waveform);
  waveform_remainder_.Swap(&appended_wave);
  ComputeFeatures();
}

template<class C>
template <class C>
void OnlineGenericBaseFeature<C>::ComputeFeatures() {
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
int64 num_samples_total = waveform_offset_ + waveform_remainder_.Dim();
Expand Down Expand Up @@ -145,7 +196,6 @@ template class OnlineGenericBaseFeature<MfccComputer>;
template class OnlineGenericBaseFeature<PlpComputer>;
template class OnlineGenericBaseFeature<FbankComputer>;


OnlineCmvnState::OnlineCmvnState(const OnlineCmvnState &other):
speaker_cmvn_stats(other.speaker_cmvn_stats),
global_cmvn_stats(other.global_cmvn_stats),
Expand Down Expand Up @@ -173,8 +223,6 @@ void OnlineCmvnState::Read(std::istream &is, bool binary) {
ExpectToken(is, binary, "</OnlineCmvnState>");
}



OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts,
const OnlineCmvnState &cmvn_state,
OnlineFeatureInterface *src):
Expand Down Expand Up @@ -328,7 +376,8 @@ void OnlineCmvn::SmoothOnlineCmvnStats(const MatrixBase<double> &speaker_stats,
// If count exceeded cmn_window it would be an error in how "window_stats"
// was accumulated.
KALDI_ASSERT(cur_count <= 1.001 * opts.cmn_window);
if (cur_count >= opts.cmn_window) return;
if (cur_count >= opts.cmn_window)
return;
if (speaker_stats.NumRows() != 0) { // if we have speaker stats..
double count_from_speaker = opts.cmn_window - cur_count,
speaker_count = speaker_stats(0, dim);
Expand All @@ -341,7 +390,8 @@ void OnlineCmvn::SmoothOnlineCmvnStats(const MatrixBase<double> &speaker_stats,
speaker_stats);
cur_count = (*stats)(0, dim);
}
if (cur_count >= opts.cmn_window) return;
if (cur_count >= opts.cmn_window)
return;
if (global_stats.NumRows() != 0) {
double count_from_global = opts.cmn_window - cur_count,
global_count = global_stats(0, dim);
Expand Down Expand Up @@ -433,7 +483,7 @@ void OnlineCmvn::SetState(const OnlineCmvnState &cmvn_state) {

int32 OnlineSpliceFrames::NumFramesReady() const {
int32 num_frames = src_->NumFramesReady();
if (num_frames > 0 && src_->IsLastFrame(num_frames-1))
if (num_frames > 0 && src_->IsLastFrame(num_frames - 1))
return num_frames;
else
return std::max<int32>(0, num_frames - right_context_);
Expand Down
11 changes: 7 additions & 4 deletions src/feat/online-feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,7 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature {
// more waveform. This will help flush out the last frame or two
// of features, in the case where snip-edges == false; it also
// affects the return value of IsLastFrame().
// Defined in online-feature.cc.
virtual void InputFinished();

private:
// This function computes any additional feature frames that it is possible to
Expand All @@ -127,8 +124,14 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature {
// waveform_remainder_ while incrementing waveform_offset_ by the same amount.
void ComputeFeatures();

// Called with the sampling rate of incoming waveform.  May create resampler_
// when that rate differs from the rate the feature computer expects; if
// resampler_ already exists, checks that its rates are unchanged.  Raises an
// error on a sampling-rate mismatch that it does not handle by resampling.
void MaybeCreateResampler(BaseFloat sampling_rate);

C computer_; // class that does the MFCC or PLP or filterbank computation

// Resampler used when the input sampling frequency is not equal to
// the expected sampling rate; null otherwise.
std::unique_ptr<LinearResample> resampler_;

FeatureWindowFunction window_function_;

// features_ is the Mfcc or Plp or Fbank features that we have already computed.
Expand Down
4 changes: 4 additions & 0 deletions src/feat/resample.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ class LinearResample {
/// Resample(x, y, true) for the last piece. Call it unnecessarily between
/// signals; calling it unnecessarily will not do any harm.
void Reset();

/// Return the input and output sampling rates (for checks, for example).
inline int32 GetInputSamplingRate() const { return samp_rate_in_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add spaces inside the {}? I think the final semicolon is not required.

/// Returns the output sampling rate (samp_rate_out_).
inline int32 GetOutputSamplingRate() const { return samp_rate_out_; }
private:
/// This function outputs the number of output samples we will output
/// for a signal with "input_num_samp" input samples. If flush == true,
Expand Down