forked from NVIDIA/OpenSeq2Seq
-
Notifications
You must be signed in to change notification settings - Fork 0
/
doItAll.py
74 lines (58 loc) · 2.62 KB
/
doItAll.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from webTranscriptToTXT import webTranscriptToTXT
from infer2text import infer2text
from splitAndConvertMP3 import splitAndConvertMP3
from makeCSV import makeCSV
import sys
from subprocess import call
import os
from time import time
from compareWER import quickWER
def _print_wer_reports(dir):
    """Print word-error-rate comparisons for the outputs in *dir*, best effort.

    Scoring is optional: a missing transcript is reported and skipped, and
    an error while scoring is printed rather than aborting the pipeline.
    ``except Exception`` (instead of a bare ``except``) keeps Ctrl-C and
    SystemExit working.

    Args:
        dir: Episode directory holding the transcripts and model predictions.
    """
    truth = os.path.join(dir, 'converted_transcript.txt')
    with_spellcheck = os.path.join(dir, 'prediction_with_spellcheck.txt')
    no_spellcheck = os.path.join(dir, 'prediction_no_spellcheck.txt')
    podscribe = os.path.join(dir, 'podscribe.txt')
    original = os.path.join(dir, 'original.txt')

    # Only score against the web transcript if one was provided.
    if os.path.isfile(original):
        try:
            # NOTE(review): presumably webTranscriptToTXT writes
            # converted_transcript.txt, which the comparisons read — confirm.
            webTranscriptToTXT(original)
            print('Truth/Model W/ Spellcheck: ' + str(quickWER(truth, with_spellcheck)))
            print('Truth/Model W/O Spellcheck: ' + str(quickWER(truth, no_spellcheck)))
        except Exception as err:
            print('Could not compare against web transcript: {}'.format(err))
    else:
        print('No web transcript provided')

    # Optional third-party (Podscribe) transcript for extra comparisons.
    if os.path.isfile(podscribe):
        try:
            print('Truth/Podscribe: ' + str(quickWER(truth, podscribe)))
            print('Podscribe/Model: ' + str(quickWER(podscribe, with_spellcheck)))
        except Exception as err:
            print('Could not compare against Podscribe transcript: {}'.format(err))


def doItAll(dir):
    """Run the whole transcription pipeline for one episode directory.

    Steps, in order:
      1. Split ``<dir>/original.mp3`` into WAV clips under ``<dir>/wavs``
         (mono, 16 kHz, 5-second pieces, done by ``splitAndConvertMP3``).
      2. Build ``<dir>/wavs/model_input.csv`` with ``makeCSV``.
      3. Copy the base ``config.py`` (read from the current working
         directory) to ``<dir>/config.py``, with the CSV path substituted
         for the ``# insert path to csv here`` placeholder.
      4. Run ``run.py --mode=infer`` on that config, writing
         ``<dir>/model_output.txt``, and pass the result to ``infer2text``.
      5. Print WER comparisons when reference transcripts are present.
      6. Print the elapsed wall-clock time in seconds.

    Args:
        dir: Path to the episode directory. The name shadows the builtin
            ``dir`` but is kept for backward compatibility with callers.

    Raises:
        RuntimeError: if the inference subprocess exits with a non-zero
            status, since every later step reads its output file.
    """
    t0 = time()

    wav_dir = os.path.join(dir, 'wavs')
    csv_path = os.path.join(wav_dir, 'model_input.csv')
    config_path = os.path.join(dir, 'config.py')
    model_output = os.path.join(dir, 'model_output.txt')

    # Convert from MP3 to WAV, change channels to 1,
    # change sample-rate to 16000, and split every 5 seconds.
    # os.makedirs replaces shelling out to `mkdir`, which errored when the
    # directory already existed (re-runs) and is not portable.
    os.makedirs(wav_dir, exist_ok=True)
    print('\n\n ** Splitting and converting MP3')
    splitAndConvertMP3(os.path.join(dir, 'original.mp3'))

    print('\n\n ** Reading base config file')
    # The base config template is read relative to the current directory.
    with open('config.py', 'r') as config_file:
        config = config_file.read()

    print('\n\n ** Creating model_input.csv')
    makeCSV(wav_dir)

    print('\n\n ** Creating config file')
    placeholder = '# insert path to csv here'
    if placeholder not in config:
        # Without the placeholder, the generated config would silently keep
        # the template's data path, so make that visible.
        print('WARNING: {!r} not found in config.py; {} will not reference {}'.format(
            placeholder, config_path, csv_path))
    # repr() yields a valid Python string literal even when the path holds
    # quotes or backslashes, unlike wrapping it in bare double quotes.
    config = config.replace(placeholder, repr(csv_path))
    with open(config_path, 'w') as config_file:
        config_file.write(config)

    print('\n\n ** Calling model from terminal')
    # Use the interpreter running this script: a bare 'python' on PATH may
    # be a different version or environment.
    call_args = [
        sys.executable or 'python',
        'run.py',
        '--config_file=' + config_path,
        '--mode=infer',
        '--infer_output_file=' + model_output,
    ]
    returncode = call(call_args)
    if returncode != 0:
        raise RuntimeError(
            'run.py exited with status {}; see its output above'.format(returncode))
    infer2text(model_output)

    _print_wer_reports(dir)

    print(time() - t0)
if __name__ == '__main__':
    # Expect the episode directory as the first argument. Extra arguments
    # are ignored, as before. Print a usage message instead of an IndexError.
    if len(sys.argv) < 2:
        sys.exit('usage: python {} <episode_dir>'.format(sys.argv[0]))
    doItAll(sys.argv[1])