Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* Bug fixes

* Fix dictionary setup from previous runs

* Fix subset issue with multiple dictionaries

* Fix bug in silence probability calculation
  • Loading branch information
mmcauliffe committed May 5, 2022
1 parent 5344d52 commit 4f4283c
Show file tree
Hide file tree
Showing 15 changed files with 346 additions and 230 deletions.
9 changes: 9 additions & 0 deletions docs/source/changelog/changelog_2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@
Release candidates
==================

2.0.0rc7
--------

- Fixed a bug where the silence probability correction was not being calculated correctly
- Fixed a bug where sample rate could not be specified when not using multiprocessing :github_pr:`444`
- Fixed an incompatibility with Kaldi version 1016 where BLAS libraries were not operating in single-threaded mode
- Further optimized large multispeaker dictionary loading
- Fixed a bug where subsets were not properly generated when multiple dictionaries were used

2.0.0rc6
--------

Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/helper/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
WordData
WordType
PhoneType
DatabaseImportData
PronunciationProbabilityCounter
CtmInterval -- Data class for representing intervals in Kaldi's CTM files
5 changes: 1 addition & 4 deletions montreal_forced_aligner/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def __init__(
):
super().__init__(**kwargs)
self._db_engine = None
self._session = None

def initialize_database(self) -> None:
"""
Expand Down Expand Up @@ -240,9 +239,7 @@ def session(self, **kwargs) -> Session:
SqlAlchemy session
"""
autoflush = kwargs.pop("autoflush", False)
if self._session is None:
self._session = sqlalchemy.orm.Session(self.db_engine, autoflush=autoflush, **kwargs)
return self._session
return sqlalchemy.orm.Session(self.db_engine, autoflush=autoflush, **kwargs)


class MfaWorker(metaclass=abc.ABCMeta):
Expand Down
20 changes: 12 additions & 8 deletions montreal_forced_aligner/alignment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def format_correction(correction_value: float) -> float:

begin = time.time()
dictionary_counters = {
dict_id: PronunciationProbabilityCounter() for dict_id in self.dictionary_lookup.keys()
dict_id: PronunciationProbabilityCounter()
for dict_id in self.dictionary_lookup.values()
}
self.log_info("Generating pronunciations...")
arguments = self.generate_pronunciations_arguments()
Expand Down Expand Up @@ -295,17 +296,21 @@ def format_correction(correction_value: float) -> float:
silence_prob = silence_probabilities[w_p1]
bar_count_silence_wp[w_p2] += counts["silence"] * silence_prob
bar_count_non_silence_wp[w_p2] += counts["non_silence"] * (1 - silence_prob)
for w_p, silence_count in counter.silence_before_counts.items():
if w_p[0] in {initial_key[0], final_key[0], self.silence_word}:
for w, p, _ in pronunciations:
silence_count = counter.silence_before_counts[(w, p)]
if w in {initial_key[0], final_key[0], self.silence_word}:
continue
non_silence_count = counter.non_silence_before_counts[w_p]
non_silence_count = counter.non_silence_before_counts[(w, p)]
pron_mapping[(w, p)]["silence_before_correction"] = format_correction(
(silence_count + lambda_3) / (bar_count_silence_wp[w_p] + lambda_3)
(silence_count + lambda_3) / (bar_count_silence_wp[(w, p)] + lambda_3)
)

pron_mapping[(w, p)]["non_silence_before_correction"] = format_correction(
(non_silence_count + lambda_3) / (bar_count_non_silence_wp[w_p] + lambda_3)
(non_silence_count + lambda_3)
/ (bar_count_non_silence_wp[(w, p)] + lambda_3)
)
session.bulk_update_mappings(Pronunciation, pron_mapping.values())
session.flush()
initial_silence_count = counter.silence_before_counts[initial_key] + (
silence_probability * lambda_2
)
Expand Down Expand Up @@ -346,10 +351,9 @@ def format_correction(correction_value: float) -> float:
self.final_non_silence_correction = (
final_non_silence_correction_sum / self.num_dictionaries
)
session.bulk_update_mappings(Pronunciation, pron_mapping.values())
session.bulk_update_mappings(Dictionary, dictionary_mappings)
session.commit()
self.log_debug(f"Alignment round took {time.time() - begin}")
self.log_debug(f"Calculating pronunciation probabilities took {time.time() - begin}")

def _collect_alignments(self):
"""
Expand Down
24 changes: 14 additions & 10 deletions montreal_forced_aligner/corpus/acoustic_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
AcousticDirectoryParser,
CorpusProcessWorker,
)
from montreal_forced_aligner.data import DatabaseImportData
from montreal_forced_aligner.db import (
Corpus,
File,
Expand Down Expand Up @@ -91,7 +92,7 @@ class AcousticCorpusMixin(CorpusMixin, FeatureConfigMixin, metaclass=ABCMeta):
def __init__(self, audio_directory: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.audio_directory = audio_directory
self.sound_file_errors = {}
self.sound_file_errors = []
self.transcriptions_without_wavs = []
self.no_transcription_files = []
self.stopped = Stopped()
Expand Down Expand Up @@ -503,8 +504,11 @@ def mfcc(self) -> None:
else:
break
continue
if isinstance(result, KaldiProcessingError):
error_dict[result.job_name] = result
if isinstance(result, Exception):
key = "error"
if isinstance(result, KaldiProcessingError):
key = result.job_name
error_dict[key] = result
continue
pbar.update(result)
for p in procs:
Expand Down Expand Up @@ -799,6 +803,7 @@ def _load_corpus_from_source_mp(self) -> None:
procs.append(p)
p.start()
last_poll = time.time() - 30
import_data = DatabaseImportData()
try:
with self.session() as session:
with tqdm.tqdm(total=100, disable=getattr(self, "quiet", False)) as pbar:
Expand Down Expand Up @@ -828,14 +833,14 @@ def _load_corpus_from_source_mp(self) -> None:
error_dict[error_type] = []
error_dict[error_type].append(error)
else:
self.add_file(file, session)
import_data.add_objects(self.generate_import_objects(file))

self.log_debug(f"Processing queue: {time.process_time() - begin_time}")

if "error" in error_dict:
session.rollback()
raise error_dict["error"][1]
self._finalize_load(session)
self._finalize_load(session, import_data)
for k in ["sound_file_errors", "decode_error_files", "textgrid_read_errors"]:
if hasattr(self, k):
if k in error_dict:
Expand Down Expand Up @@ -921,6 +926,7 @@ def _load_corpus_from_source(self) -> None:
all_sound_files.update(exts.other_audio_files)
all_sound_files.update(exts.wav_files)
self.log_debug(f"Walking through {self.corpus_directory}...")
import_data = DatabaseImportData()
with self.session() as session:
for root, _, files in os.walk(self.corpus_directory, followlinks=True):
exts = find_exts(files)
Expand Down Expand Up @@ -962,18 +968,16 @@ def _load_corpus_from_source(self) -> None:
relative_path,
self.speaker_characters,
sanitize_function,
self.sample_frequency
self.sample_frequency,
)

self.add_file(file, session)
import_data.add_objects(self.generate_import_objects(file))
except TextParseError as e:
self.decode_error_files.append(e)
except TextGridParseError as e:
self.textgrid_read_errors.append(e)
except SoundFileError as e:
self.sound_file_errors.append(e)
self._finalize_load(session)
session.commit()
self._finalize_load(session, import_data)
if self.decode_error_files or self.textgrid_read_errors:
self.log_info(
"There were some issues with files in the corpus. "
Expand Down
Loading

0 comments on commit 4f4283c

Please sign in to comment.