[scripts,egs] Changes for Python 2/3 compatibility (kaldi-asr#2925)
desh2608 authored and danpovey committed Dec 31, 2018
1 parent 3e77220 commit 5a720ac
Showing 178 changed files with 982 additions and 852 deletions.
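The recurring Python 2/3 idioms applied across the diff are: adding "from __future__ import print_function" and "from __future__ import division", materializing map()/filter()/range() results as lists where a real list is needed, replacing xrange() with range(), moving iterator classes to the __next__() protocol, making classes new-style by inheriting from object, rewriting "print >> sys.stderr, ..." as "print(..., file=sys.stderr)", and dropping the redundant type=str from argparse arguments. Short illustrative sketches of these idioms (not taken from the commit itself) follow the relevant files below.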
1 change: 1 addition & 0 deletions egs/aishell2/s5/local/word_segmentation.py
@@ -4,6 +4,7 @@
# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
# Apache 2.0

+ from __future__ import print_function
import sys
import jieba
reload(sys)
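As a sketch of the most common change in this commit (not itself part of the diff): from __future__ import print_function makes print a function under Python 2.7, so the same syntax runs unmodified under Python 3.

    # Hypothetical standalone example, not from the repository.
    from __future__ import print_function
    import sys

    print("segmentation done")                      # function call on Py2 and Py3
    print("warning: empty input", file=sys.stderr)  # Py2-only form was: print >> sys.stderr, "..."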
5 changes: 3 additions & 2 deletions egs/ami/s5/local/sort_bad_utts.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python

+ from __future__ import print_function
import sys
import argparse
import logging
@@ -38,10 +39,10 @@ def GetSortedWers(utt_info_file):
utt_wer_sorted = sorted(utt_wer, key = lambda k : k[1])
try:
import numpy as np
- bins = range(0,105,5)
+ bins = list(range(0,105,5))
bins.append(sys.float_info.max)

- hist, bin_edges = np.histogram(map(lambda x: x[1], utt_wer_sorted),
+ hist, bin_edges = np.histogram([x[1] for x in utt_wer_sorted],
bins = bins)
num_utts = len(utt_wer)
string = ''
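Both replacements in sort_bad_utts.py address the same Python 3 change: range(), map(), and filter() return lazy iterables rather than lists, so code that appends to or indexes the result must materialize a list. A rough sketch of the failure mode, using hypothetical values:

    # Hypothetical example, not from the repository.
    import sys

    bins = list(range(0, 105, 5))       # plain range(...) has no .append() on Py3
    bins.append(sys.float_info.max)

    utt_wer = [("utt1", 12.0), ("utt2", 3.5)]
    wers = [x[1] for x in utt_wer]      # eager list on both versions, unlike
                                        # map(lambda x: x[1], utt_wer) on Py3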
1 change: 1 addition & 0 deletions egs/an4/s5/local/data_prep.py
@@ -15,6 +15,7 @@
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.

+ from __future__ import print_function
import os
import re
import sys
1 change: 1 addition & 0 deletions egs/an4/s5/local/lexicon_prep.py
@@ -15,6 +15,7 @@
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.

+ from __future__ import print_function
import os
import re
import sys
@@ -4,13 +4,14 @@
# creates a segments file in the provided data directory
# into uniform segments with specified window and overlap

+ from __future__ import division
import imp, sys, argparse, os, math, subprocess

min_segment_length = 10 # in seconds
def segment(total_length, window_length, overlap = 0):
increment = window_length - overlap
num_windows = int(math.ceil(float(total_length)/increment))
- segments = map(lambda x: (x * increment, min( total_length, (x * increment) + window_length)), range(0, num_windows))
+ segments = [(x * increment, min( total_length, (x * increment) + window_length)) for x in range(0, num_windows)]
if segments[-1][1] - segments[-1][0] < min_segment_length:
segments[-2] = (segments[-2][0], segments[-1][1])
segments.pop()
@@ -53,7 +54,7 @@ def prepare_segments_file(kaldi_data_dir, window_length, overlap):
parser = argparse.ArgumentParser()
parser.add_argument('--window-length', type = float, default = 30.0, help = 'length of the window used to cut the segment')
parser.add_argument('--overlap', type = float, default = 5.0, help = 'overlap of neighboring windows')
- parser.add_argument('data_dir', type=str, help='directory such as data/train')
+ parser.add_argument('data_dir', help='directory such as data/train')

params = parser.parse_args()

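Two idioms appear in the uniform-segmentation script above: from __future__ import division gives / the true-division semantics of Python 3 under Python 2, and type=str is dropped from add_argument() because str is already argparse's default. A small sketch with hypothetical values:

    # Hypothetical example, not from the repository.
    from __future__ import division
    import argparse, math

    total_length, increment = 95, 25
    # With the import, 95/25 == 3.8 on Py2 as well; without it, Py2 floor-divides
    # to 3 and the ceil below would under-count the windows.
    num_windows = int(math.ceil(total_length / increment))   # 4

    parser = argparse.ArgumentParser()
    parser.add_argument('data_dir', help='directory such as data/train')  # type=str is the default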
@@ -38,14 +38,14 @@ def fill_ctm(input_ctm_file, output_ctm_file, recording_names):

sys.stderr.write(str(" ".join(sys.argv)))
parser = argparse.ArgumentParser(usage)
- parser.add_argument('input_ctm_file', type=str, help='ctm file for the recordings')
- parser.add_argument('output_ctm_file', type=str, help='ctm file for the recordings')
- parser.add_argument('recording_name_file', type=str, help='file with names of the recordings')
+ parser.add_argument('input_ctm_file', help='ctm file for the recordings')
+ parser.add_argument('output_ctm_file', help='ctm file for the recordings')
+ parser.add_argument('recording_name_file', help='file with names of the recordings')

params = parser.parse_args()

try:
- file_names = map(lambda x: x.strip(), open("{0}".format(params.recording_name_file)).readlines())
+ file_names = [x.strip() for x in open("{0}".format(params.recording_name_file)).readlines()]
except IOError:
raise Exception("Expected to find {0}".format(params.recording_name_file))

3 changes: 2 additions & 1 deletion egs/aspire/s5/local/multi_condition/get_air_file_patterns.py
@@ -3,6 +3,7 @@

# script to generate the file_patterns of the AIR database
# see load_air.m file in AIR db to understand the naming convention
+ from __future__ import print_function
import sys, glob, re, os.path

air_dir = sys.argv[1]
@@ -45,4 +46,4 @@
file_patterns.append(file_pattern+" "+output_file_name)
file_patterns = list(set(file_patterns))
file_patterns.sort()
print "\n".join(file_patterns)
print("\n".join(file_patterns))
8 changes: 5 additions & 3 deletions egs/aspire/s5/local/multi_condition/normalize_wavs.py
@@ -3,6 +3,8 @@

# normalizes the wave files provided in input file list with a common scaling factor
# the common scaling factor is computed to 1/\sqrt(1/(total_samples) * \sum_i{\sum_j x_i(j)^2}) where total_samples is sum of all samples of all wavefiles. If the data is multi-channel then each channel is treated as a seperate wave files
+ from __future__ import division
+ from __future__ import print_function
import argparse, scipy.io.wavfile, warnings, numpy as np, math

def get_normalization_coefficient(file_list, is_rir, additional_scaling):
@@ -29,7 +31,7 @@ def get_normalization_coefficient(file_list, is_rir, additional_scaling):
assert(rate == sampling_rate)
else:
sampling_rate = rate
- data = data / dtype_max_value
+ data = data/dtype_max_value
if is_rir:
# just count the energy of the direct impulse response
# this is treated as energy of signal from 0.001 seconds before impulse
@@ -55,8 +57,8 @@ def get_normalization_coefficient(file_list, is_rir, additional_scaling):
except IOError:
warnings.warn("Did not find the file {0}.".format(file))
assert(total_samples > 0)
- scaling_coefficient = np.sqrt(total_samples / total_energy)
- print "Scaling coefficient is {0}.".format(scaling_coefficient)
+ scaling_coefficient = np.sqrt(total_samples/total_energy)
+ print("Scaling coefficient is {0}.".format(scaling_coefficient))
if math.isnan(scaling_coefficient):
raise Exception(" Nan encountered while computing scaling coefficient. This is mostly due to numerical overflow")
return scaling_coefficient
6 changes: 3 additions & 3 deletions egs/aspire/s5/local/multi_condition/read_rir.py
@@ -29,9 +29,9 @@ def usage():
#sys.stderr.write(" ".join(sys.argv)+"\n")
parser = argparse.ArgumentParser(usage())
parser.add_argument('--output-sampling-rate', type = int, default = 8000, help = 'sampling rate of the output')
- parser.add_argument('type', type = str, default = None, help = 'database type', choices = ['air'])
- parser.add_argument('input', type = str, default = None, help = 'directory containing the multi-channel data for a particular recording, or file name or file-regex-pattern')
- parser.add_argument('output_filename', type = str, default = None, help = 'output filename (if "-" then output is written to output pipe)')
+ parser.add_argument('type', default = None, help = 'database type', choices = ['air'])
+ parser.add_argument('input', default = None, help = 'directory containing the multi-channel data for a particular recording, or file name or file-regex-pattern')
+ parser.add_argument('output_filename', default = None, help = 'output filename (if "-" then output is written to output pipe)')
params = parser.parse_args()

if params.output_filename == "-":
14 changes: 8 additions & 6 deletions egs/aspire/s5/local/multi_condition/reverberate_wavs.py
@@ -4,18 +4,20 @@
# script to generate multicondition training data / dev data / test data
import argparse, glob, math, os, random, scipy.io.wavfile, sys

- class list_cyclic_iterator:
+ class list_cyclic_iterator(object):
def __init__(self, list, random_seed = 0):
self.list_index = 0
self.list = list
random.seed(random_seed)
random.shuffle(self.list)

- def next(self):
+ def __next__(self):
item = self.list[self.list_index]
self.list_index = (self.list_index + 1) % len(self.list)
return item

+ next = __next__ # for Python 2

def return_nonempty_lines(lines):
new_lines = []
for line in lines:
@@ -71,15 +73,15 @@ def return_nonempty_lines(lines):
for i in range(len(wav_files)):
wav_file = " ".join(wav_files[i].split()[1:])
output_wav_file = wav_out_files[i]
- impulse_file = impulses.next()
+ impulse_file = next(impulses)
noise_file = ''
snr = ''
found_impulse = False
if add_noise:
- for i in xrange(len(impulse_noise_index)):
+ for i in range(len(impulse_noise_index)):
if impulse_file in impulse_noise_index[i][0]:
- noise_file = impulse_noise_index[i][1].next()
- snr = snrs.next()
+ noise_file = next(impulse_noise_index[i][1])
+ snr = next(snrs)
assert(len(wav_file.strip()) > 0)
assert(len(impulse_file.strip()) > 0)
assert(len(noise_file.strip()) > 0)
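The iterator changes in reverberate_wavs.py follow Python 3's protocol: the method is __next__(), call sites use the next() builtin (which works on both versions), xrange() is gone, and the "next = __next__" alias keeps the class usable from Python 2 code that calls .next() directly. A condensed sketch, with a hypothetical file list:

    # Hypothetical example, not from the repository.
    import random

    class list_cyclic_iterator(object):          # explicit new-style class for Py2
        def __init__(self, items, random_seed=0):
            self.index = 0
            self.items = list(items)
            random.seed(random_seed)
            random.shuffle(self.items)

        def __next__(self):                      # Python 3 iterator protocol
            item = self.items[self.index]
            self.index = (self.index + 1) % len(self.items)
            return item

        next = __next__                          # Python 2 alias

    impulses = list_cyclic_iterator(["rir1.wav", "rir2.wav"])
    for _ in range(3):                           # xrange() no longer exists on Py3
        print(next(impulses))                    # builtin next() dispatches on both versions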
19 changes: 10 additions & 9 deletions egs/babel/s5b/local/lonestar.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python
+ from __future__ import print_function
from pylauncher import *
import pylauncher
import sys
@@ -39,7 +40,7 @@ def KaldiLauncher(lo, **kwargs):

logfiles = list()
commands = list()
- for q in xrange(lo.jobstart, lo.jobend+1):
+ for q in range(lo.jobstart, lo.jobend+1):
s = "bash " + lo.queue_scriptfile + " " + str(q)
commands.append(s)

@@ -74,7 +75,7 @@ def KaldiLauncher(lo, **kwargs):
time.sleep(delay);

lines=tail(10, logfile)
- with_status=filter(lambda x:re.search(r'with status (\d+)', x), lines)
+ with_status=[x for x in lines if re.search(r'with status (\d+)', x)]

if len(with_status) == 0:
sys.stderr.write("The last line(s) of the log-file " + logfile + " does not seem"
@@ -98,7 +99,7 @@ def KaldiLauncher(lo, **kwargs):
sys.exit(-1);

#Remove service files. Be careful not to remove something that might be needed in problem diagnostics
- for i in xrange(len(commands)):
+ for i in range(len(commands)):
out_file=os.path.join(qdir, ce.outstring+str(i))

#First, let's wait on files missing (it might be that those are missing
@@ -149,7 +150,7 @@ def KaldiLauncher(lo, **kwargs):

#print job.final_report()

- class LauncherOpts:
+ class LauncherOpts(object):
def __init__(self):
self.sync=0
self.nof_threads = 1
@@ -199,7 +200,7 @@ def CmdLineParser(argv):
jobend=int(m.group(2))
argv.pop(0)
elif re.match("^.+=.*:.*$", argv[0]):
- print >> sys.stderr, "warning: suspicious JOB argument " + argv[0];
+ print("warning: suspicious JOB argument " + argv[0], file=sys.stderr);

if jobstart > jobend:
sys.stderr.write("lonestar.py: JOBSTART("+ str(jobstart) + ") must be lower than JOBEND(" + str(jobend) + ")\n")
@@ -238,8 +239,8 @@ def setup_paths_and_vars(opts):
cwd = os.getcwd()

if opts.varname and (opts.varname not in opts.logfile ) and (opts.jobstart != opts.jobend):
- print >>sys.stderr, "lonestar.py: you are trying to run a parallel job" \
- "but you are putting the output into just one log file (" + opts.logfile + ")";
+ print("lonestar.py: you are trying to run a parallel job" \
+ "but you are putting the output into just one log file (" + opts.logfile + ")", file=sys.stderr);
sys.exit(1)

if not os.path.isabs(opts.logfile):
@@ -261,8 +262,8 @@ def setup_paths_and_vars(opts):
taskname=os.path.basename(queue_logfile)
taskname = taskname.replace(".log", "");
if taskname == "":
- print >> sys.stderr, "lonestar.py: you specified the log file name in such form " \
- "that leads to an empty task name ("+logfile + ")";
+ print("lonestar.py: you specified the log file name in such form " \
+ "that leads to an empty task name ("+logfile + ")", file=sys.stderr);
sys.exit(1)

if not os.path.isabs(queue_logfile):
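The print changes in lonestar.py replace the Python 2 chevron syntax, which is a SyntaxError under Python 3. With the print_function import, the file= keyword is the portable spelling:

    # Hypothetical example, not from the repository.
    from __future__ import print_function
    import sys

    # Py2 only:  print >> sys.stderr, "warning: suspicious JOB argument"
    print("warning: suspicious JOB argument", file=sys.stderr)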
29 changes: 15 additions & 14 deletions egs/babel/s5b/local/resegment/segmentation.py
@@ -3,6 +3,7 @@
# Copyright 2014 Vimal Manohar
# Apache 2.0

+ from __future__ import division
import os, glob, argparse, sys, re, time
from argparse import ArgumentParser

@@ -19,12 +20,12 @@

def mean(l):
if len(l) > 0:
- return float(sum(l)) / len(l)
+ return (float(sum(l))/len(l))
return 0

# Analysis class
# Stores statistics like the confusion matrix, length of the segments etc.
- class Analysis:
+ class Analysis(object):
def __init__(self, file_id, frame_shift, prefix):
self.confusion_matrix = [0] * 9
self.type_counts = [ [[] for j in range(0,9)] for i in range(0,3) ]
@@ -274,8 +275,8 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift):
i = len(this_file)
category = splits[6]
word = splits[5]
- start_time = int(float(splits[3])/frame_shift + 0.5)
- duration = int(float(splits[4])/frame_shift + 0.5)
+ start_time = int((float(splits[3])/frame_shift) + 0.5)
+ duration = int((float(splits[4])/frame_shift) + 0.5)
if i < start_time:
this_file.extend(["0"]*(start_time - i))
if type1 == "NON-LEX":
Expand All @@ -295,7 +296,7 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift):
# Stats class to store some basic stats about the number of
# times the post-processor goes through particular loops or blocks
# of code in the algorithm. This is just for debugging.
- class Stats:
+ class Stats(object):
def __init__(self):
self.inter_utt_nonspeech = 0
self.merge_nonspeech_segment = 0
@@ -321,7 +322,7 @@ def reset(self):
self.noise_only = 0

# Timer class to time functions
- class Timer:
+ class Timer(object):
def __enter__(self):
self.start = time.clock()
return self
@@ -332,7 +333,7 @@ def __exit__(self, *args):
# The main class for post-processing a file.
# This does the segmentation either looking at the file isolated
# or by looking at both classes simultaneously
- class JointResegmenter:
+ class JointResegmenter(object):
def __init__(self, P, A, f, options, phone_map, stats = None, reference = None):

# Pointers to prediction arrays and Initialization
@@ -1290,22 +1291,22 @@ def main():
dest='hard_max_segment_length', default=15.0, \
help="Hard maximum on the segment length above which the segment " \
+ "will be broken even if in the middle of speech (default: %(default)s)")
- parser.add_argument('--first-separator', type=str, \
+ parser.add_argument('--first-separator', \
dest='first_separator', default="-", \
help="Separator between recording-id and start-time (default: %(default)s)")
- parser.add_argument('--second-separator', type=str, \
+ parser.add_argument('--second-separator', \
dest='second_separator', default="-", \
help="Separator between start-time and end-time (default: %(default)s)")
- parser.add_argument('--remove-noise-only-segments', type=str, \
+ parser.add_argument('--remove-noise-only-segments', \
dest='remove_noise_only_segments', default="true", choices=("true", "false"), \
help="Remove segments that have only noise. (default: %(default)s)")
parser.add_argument('--min-inter-utt-silence-length', type=float, \
dest='min_inter_utt_silence_length', default=1.0, \
help="Minimum silence that must exist between two separate utterances (default: %(default)s)");
- parser.add_argument('--channel1-file', type=str, \
+ parser.add_argument('--channel1-file', \
dest='channel1_file', default="inLine", \
help="String that matches with the channel 1 file (default: %(default)s)")
- parser.add_argument('--channel2-file', type=str, \
+ parser.add_argument('--channel2-file', \
dest='channel2_file', default="outLine", \
help="String that matches with the channel 2 file (default: %(default)s)")
parser.add_argument('--isolated-resegmentation', \
@@ -1388,7 +1389,7 @@ def main():

speech_cap = None
if options.speech_cap_length != None:
- speech_cap = int( options.speech_cap_length / options.frame_shift )
+ speech_cap = int(options.speech_cap_length/options.frame_shift)
# End if

for f in pred_files:
@@ -1454,7 +1455,7 @@ def main():
f2 = f3
# End if

- if (len(A1) - len(A2)) > options.max_length_diff / options.frame_shift:
+ if (len(A1) - len(A2)) > int(options.max_length_diff/options.frame_shift):
sys.stderr.write( \
"%s: Warning: Lengths of %s and %s differ by more than %f. " \
% (sys.argv[0], f1,f2, options.max_length_diff) \
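Beyond the division tweaks, the class statements in segmentation.py gain an explicit (object) base. Under Python 2 that opts into new-style classes, whose instances report their real type; under Python 3 every class is new-style, so the base is a harmless no-op. A minimal sketch:

    # Hypothetical example, not from the repository.
    class OldStyle:            # old-style under Py2, new-style under Py3
        pass

    class NewStyle(object):    # new-style under both
        pass

    # Under Py2, type(OldStyle()) is the generic <type 'instance'>, which breaks
    # descriptors and some type()-based dispatch; NewStyle behaves like a
    # Python 3 class on both versions.
    print(type(NewStyle()) is NewStyle)    # True on both versions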
(The remaining changed files in this commit are not shown.)
