Skip to content

Commit

Permalink
Add more helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
reuben committed Oct 21, 2019
1 parent 2a080ad commit 6f45cdd
Showing 1 changed file with 38 additions and 20 deletions.
58 changes: 38 additions & 20 deletions dataframe_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pandas
import regex as re
import subprocess
import sys
import wave
Expand All @@ -23,9 +24,9 @@ def limit_repeated_samples(df, limit):
return df

# Collect samples with transcripts that repeat <= `limit` times in the dataset.
idx_le_limit = df['transcript'].apply(lambda x: counter[x] <= limit)
data_le_limit = df[idx_le_limit]
data_gt_limit = df[~idx_le_limit]
nb_of_repeats = df['transcript'].apply(lambda x: counter[x])
data_le_limit = df[nb_of_repeats <= limit]
data_gt_limit = df[nb_of_repeats > limit]

# Sentences that repeat > `limit` times in the dataset.
sentences_gt_limit = [s for s, count in counter.most_common() if count > limit]
Expand Down Expand Up @@ -75,15 +76,15 @@ def generate_unique_dev_test(df, dev_size, test_size):
# formats used in the wild for .wav files and TensorFlow only accepts the format
# generated by SoX, so if the header is in a different format you'll need to
# transcode the files with SoX.
def invalid_header(file):
def is_invalid_header(file):
    # Read the fixed-size canonical RIFF/WAVE header prefix and pull out the
    # fmt-chunk fields at their standard offsets: audio format @20,
    # channel count @22, sample rate @24, bits per sample @34.
    with open(file, 'rb') as fin:
        header = fin.read(36)
        audio_format, num_channels, sample_rate = struct.unpack_from('<HHI', header, 20)
        bits_per_sample, = struct.unpack_from('<H', header, 34)
        # Valid means: PCM (format 1), mono, 16 kHz, 16-bit samples.
        return not (audio_format == 1
                    and num_channels == 1
                    and sample_rate == 16000
                    and bits_per_sample == 16)


# You can do this to transcode files with different header formats:
#
# invalid = df['wav_filename'].apply(invalid_header)
# invalid = df['wav_filename'].apply(is_invalid_header)
# transcode_files(df, invalid)
#
def transcode_files(df, idx_to_transcode):
Expand All @@ -93,16 +94,16 @@ def transcode_files(df, idx_to_transcode):


# In case of the following TensorFlow error, and you're ABSOLUTELY sure all
# files are PCM Mono 16000 Hz, 16-bit per sample:
# files are PCM Mono 16000 Hz, 16 bits per sample:
#
# tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
# (0) Invalid argument: Bad bytes per sample in WAV header: Expected 2 but got 4
# [[{{node DecodeWav}}]]
# [[tower_7/IteratorGetNext]]
#
# Use with: df['wav_filename'].apply(fix_header)
# Use with: df['wav_filename'].apply(fix_header_bytes_per_sample)
#
def fix_header(wav_filename):
def fix_header_bytes_per_sample(wav_filename):
with open(wav_filename, 'r+b') as fio:
header = bytearray(fio.read(44))
bytes_per_sample = struct.unpack_from('<H', header[32:34])
Expand All @@ -114,29 +115,35 @@ def fix_header(wav_filename):

# Remove files that have transcripts with characters outside of the alphabet
#
# alphabet = set('abcdef...')
# df, removed = remove_files_non_alphabetic(df, alphabet)
# alphabet = set('abcdef...')
# df, removed = remove_files_non_alphabetic(df, alphabet)
#
def remove_files_non_alphabetic(df, alphabet):
    """Split `df` into (kept, removed) DataFrames: a row is removed when its
    transcript contains any character outside `alphabet`."""
    # Non-empty set difference <=> at least one out-of-alphabet character.
    has_foreign_chars = df['transcript'].apply(lambda t: bool(set(t) - alphabet))
    return df[~has_foreign_chars], df[has_foreign_chars]


# Remove characters in transcripts that aren't in the Letter Unicode character
# class (i.e. strips punctuation, math symbols, digits, etc.).
def remove_non_letters(df):
    """Strip every non-Letter character from the `transcript` column, in place.

    NOTE: \p{Letter} requires the third-party `regex` module (imported as
    `re` at the top of this file); the stdlib `re` does not support it.
    """
    non_letter = re.compile(r'[^\p{Letter}]')
    df['transcript'] = df['transcript'].map(lambda t: non_letter.sub('', t))


# Find corrupted files (header duration does not match file size). Example:
#
# invalid = df['wav_filename'].apply(compare_header_and_size)
# print('The following files are corrupted:')
# print(df[invalid].values)
# invalid = df['wav_filename'].apply(bad_header_for_filesize)
# print('The following files are corrupted:')
# print(df[invalid].values)
#
def compare_header_and_size(wav_filename):
def bad_header_for_filesize(wav_filename):
    """Return True when the file size implied by the WAV header disagrees
    with the actual on-disk size (i.e. the file is likely corrupted)."""
    # Size the file *should* have: PCM payload described by the header,
    # plus the canonical 44-byte RIFF/WAVE header.
    with wave.open(wav_filename, 'r') as fin:
        payload_size = fin.getnframes() * fin.getnchannels() * fin.getsampwidth()
    expected_size = 44 + payload_size
    actual_size = os.path.getsize(wav_filename)
    return expected_size != actual_size


# Remove files that are too short for their transcript
def remove_not_enough_windows(df, sample_rate=16000, win_step_ms=20, utf8=False):
# Find files that are too short for their transcript
def find_not_enough_windows(df, sample_rate=16000, win_step_ms=20, utf8=False):
# Compute number of windows in each file
num_samples = (df['wav_filesize'] - 44) // 2
samples_per_window = int(sample_rate * (win_step_ms / 1000.))
Expand All @@ -148,16 +155,15 @@ def remove_not_enough_windows(df, sample_rate=16000, win_step_ms=20, utf8=False)
else:
str_len = df['transcript'].str.len()

enough_windows = num_windows >= str_len
return df[enough_windows], df[~enough_windows]
return num_windows >= str_len


# Compute ratio of duration to transcript len. Extreme values likely correspond
# to problematic samples (too short for transcript, or too long for transcript).
# Example of how to visualize the histogram of ratios:
#
# ratio = duration_to_transcript_len_ratio(df)
# ratio.hist()
# ratio = duration_to_transcript_len_ratio(df)
# ratio.hist()
#
def duration_to_transcript_len_ratio(df, sample_rate=16000, utf8=False):
duration = (df['wav_filesize'] - 44) / 2 / sample_rate
Expand All @@ -167,3 +173,15 @@ def duration_to_transcript_len_ratio(df, sample_rate=16000, utf8=False):
tr_len = df['transcript'].str.len()
return duration / tr_len


# Compute RMS power from a single 16-bit per sample WAVE file
def rms(x):
    """Return the RMS power of the samples in 16-bit mono WAV file `x`.

    The samples are widened to float64 before squaring: squaring the raw
    int16 array wraps around for any |sample| > 181 (181**2 = 32761 is the
    largest square that fits in int16), which silently corrupted the result
    for all but near-silent audio.
    """
    with wave.open(x) as fin:
        samples = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
    return np.sqrt(np.mean(samples.astype(np.float64) ** 2))


# Calculate RMS power of all samples in DataFrame
def compute_rms(df):
    """Return a Series with the RMS power of every file in `df`,
    index-aligned with the input frame."""
    return df['wav_filename'].map(rms)

0 comments on commit 6f45cdd

Please sign in to comment.