diff --git a/docs/source/changelog/changelog_2.0.rst b/docs/source/changelog/changelog_2.0.rst index 56eb9407..653a8187 100644 --- a/docs/source/changelog/changelog_2.0.rst +++ b/docs/source/changelog/changelog_2.0.rst @@ -10,6 +10,15 @@ Release candidates ================== +2.0.0rc7 +-------- + +- Fixed a bug where silence correction was not being calculated correctly +- Fixed a bug where sample rate could not be specified when not using multiprocessing :github_pr:`444` +- Fixed an incompatibility with the Kaldi version 1016 where BLAS libraries were not operating in single-threaded mode +- Further optimized large multispeaker dictionary loading +- Fixed a bug where subsets were not properly generated when multiple dictionaries were used + 2.0.0rc6 -------- diff --git a/docs/source/reference/helper/data.rst b/docs/source/reference/helper/data.rst index 86bc69f2..404429f3 100644 --- a/docs/source/reference/helper/data.rst +++ b/docs/source/reference/helper/data.rst @@ -11,5 +11,6 @@ WordData WordType PhoneType + DatabaseImportData PronunciationProbabilityCounter CtmInterval -- Data class for representing intervals in Kaldi's CTM files diff --git a/montreal_forced_aligner/abc.py b/montreal_forced_aligner/abc.py index 236f3d7b..31d5276a 100644 --- a/montreal_forced_aligner/abc.py +++ b/montreal_forced_aligner/abc.py @@ -178,7 +178,6 @@ def __init__( ): super().__init__(**kwargs) self._db_engine = None - self._session = None def initialize_database(self) -> None: """ @@ -240,9 +239,7 @@ def session(self, **kwargs) -> Session: SqlAlchemy session """ autoflush = kwargs.pop("autoflush", False) - if self._session is None: - self._session = sqlalchemy.orm.Session(self.db_engine, autoflush=autoflush, **kwargs) - return self._session + return sqlalchemy.orm.Session(self.db_engine, autoflush=autoflush, **kwargs) class MfaWorker(metaclass=abc.ABCMeta): diff --git a/montreal_forced_aligner/alignment/base.py b/montreal_forced_aligner/alignment/base.py index 930baac5..184b825d 100644 --- a/montreal_forced_aligner/alignment/base.py +++ b/montreal_forced_aligner/alignment/base.py @@ -173,7 +173,8 @@ def format_correction(correction_value: float) -> float: begin = time.time() dictionary_counters = { - dict_id: PronunciationProbabilityCounter() for dict_id in self.dictionary_lookup.keys() + dict_id: PronunciationProbabilityCounter() + for dict_id in self.dictionary_lookup.values() } self.log_info("Generating pronunciations...") arguments = self.generate_pronunciations_arguments() @@ -295,17 +296,21 @@ def format_correction(correction_value: float) -> float: silence_prob = silence_probabilities[w_p1] bar_count_silence_wp[w_p2] += counts["silence"] * silence_prob bar_count_non_silence_wp[w_p2] += counts["non_silence"] * (1 - silence_prob) - for w_p, silence_count in counter.silence_before_counts.items(): - if w_p[0] in {initial_key[0], final_key[0], self.silence_word}: + for w, p, _ in pronunciations: + silence_count = counter.silence_before_counts[(w, p)] + if w in {initial_key[0], final_key[0], self.silence_word}: continue - non_silence_count = counter.non_silence_before_counts[w_p] + non_silence_count = counter.non_silence_before_counts[(w, p)] pron_mapping[(w, p)]["silence_before_correction"] = format_correction( - (silence_count + lambda_3) / (bar_count_silence_wp[w_p] + lambda_3) + (silence_count + lambda_3) / (bar_count_silence_wp[(w, p)] + lambda_3) ) pron_mapping[(w, p)]["non_silence_before_correction"] = format_correction( - (non_silence_count + lambda_3) / (bar_count_non_silence_wp[w_p] + lambda_3) + (non_silence_count + lambda_3) + / (bar_count_non_silence_wp[(w, p)] + lambda_3) ) + session.bulk_update_mappings(Pronunciation, pron_mapping.values()) + session.flush() initial_silence_count = counter.silence_before_counts[initial_key] + ( silence_probability * lambda_2 ) @@ -346,10 +351,9 @@ def format_correction(correction_value: float) -> float: self.final_non_silence_correction = ( final_non_silence_correction_sum / self.num_dictionaries ) - session.bulk_update_mappings(Pronunciation, pron_mapping.values()) session.bulk_update_mappings(Dictionary, dictionary_mappings) session.commit() - self.log_debug(f"Alignment round took {time.time() - begin}") + self.log_debug(f"Calculating pronunciation probabilities took {time.time() - begin}") def _collect_alignments(self): """ diff --git a/montreal_forced_aligner/corpus/acoustic_corpus.py b/montreal_forced_aligner/corpus/acoustic_corpus.py index e5f67ae4..51a883b3 100644 --- a/montreal_forced_aligner/corpus/acoustic_corpus.py +++ b/montreal_forced_aligner/corpus/acoustic_corpus.py @@ -32,6 +32,7 @@ AcousticDirectoryParser, CorpusProcessWorker, ) +from montreal_forced_aligner.data import DatabaseImportData from montreal_forced_aligner.db import ( Corpus, File, @@ -91,7 +92,7 @@ class AcousticCorpusMixin(CorpusMixin, FeatureConfigMixin, metaclass=ABCMeta): def __init__(self, audio_directory: Optional[str] = None, **kwargs): super().__init__(**kwargs) self.audio_directory = audio_directory - self.sound_file_errors = {} + self.sound_file_errors = [] self.transcriptions_without_wavs = [] self.no_transcription_files = [] self.stopped = Stopped() @@ -503,8 +504,11 @@ def mfcc(self) -> None: else: break continue - if isinstance(result, KaldiProcessingError): - error_dict[result.job_name] = result + if isinstance(result, Exception): + key = "error" + if isinstance(result, KaldiProcessingError): + key = result.job_name + error_dict[key] = result continue pbar.update(result) for p in procs: @@ -799,6 +803,7 @@ def _load_corpus_from_source_mp(self) -> None: procs.append(p) p.start() last_poll = time.time() - 30 + import_data = DatabaseImportData() try: with self.session() as session: with tqdm.tqdm(total=100, disable=getattr(self, "quiet", False)) as pbar: @@ -828,14 +833,14 @@ def _load_corpus_from_source_mp(self) -> None: error_dict[error_type] = [] error_dict[error_type].append(error) else: - self.add_file(file, session) + import_data.add_objects(self.generate_import_objects(file)) self.log_debug(f"Processing queue: {time.process_time() - begin_time}") if "error" in error_dict: session.rollback() raise error_dict["error"][1] - self._finalize_load(session) + self._finalize_load(session, import_data) for k in ["sound_file_errors", "decode_error_files", "textgrid_read_errors"]: if hasattr(self, k): if k in error_dict: @@ -921,6 +926,7 @@ def _load_corpus_from_source(self) -> None: all_sound_files.update(exts.other_audio_files) all_sound_files.update(exts.wav_files) self.log_debug(f"Walking through {self.corpus_directory}...") + import_data = DatabaseImportData() with self.session() as session: for root, _, files in os.walk(self.corpus_directory, followlinks=True): exts = find_exts(files) @@ -962,18 +968,16 @@ def _load_corpus_from_source(self) -> None: relative_path, self.speaker_characters, sanitize_function, - self.sample_frequency + self.sample_frequency, ) - - self.add_file(file, session) + import_data.add_objects(self.generate_import_objects(file)) except TextParseError as e: self.decode_error_files.append(e) except TextGridParseError as e: self.textgrid_read_errors.append(e) except SoundFileError as e: self.sound_file_errors.append(e) - self._finalize_load(session) - session.commit() + self._finalize_load(session, import_data) if self.decode_error_files or self.textgrid_read_errors: self.log_info( "There were some issues with files in the corpus. " diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py index 0e240e65..ca971aa9 100644 --- a/montreal_forced_aligner/corpus/base.py +++ b/montreal_forced_aligner/corpus/base.py @@ -13,7 +13,7 @@ from montreal_forced_aligner.abc import DatabaseMixin, MfaWorker from montreal_forced_aligner.corpus.classes import FileData, UtteranceData from montreal_forced_aligner.corpus.multiprocessing import Job -from montreal_forced_aligner.data import TextFileType +from montreal_forced_aligner.data import DatabaseImportData, TextFileType from montreal_forced_aligner.db import ( Corpus, Dictionary, @@ -101,12 +101,6 @@ def __init__( self._current_speaker_index = 1 self._current_file_index = 1 self._speaker_ids = {} - self._speaker_objects = [] - self._file_objects = [] - self._text_file_objects = [] - self._sound_file_objects = [] - self._speaker_ordering_objects = [] - self._utterance_objects = [] def inspect_database(self) -> None: """Check if a database file exists and create the necessary metadata""" @@ -439,23 +433,26 @@ def initialize_jobs(self) -> None: for j in self.jobs: j.refresh_dictionaries(session) - def _finalize_load(self, session): + def _finalize_load(self, session: Session, import_data: DatabaseImportData): """Finalize the import of database objects after parsing""" with session.bind.begin() as conn: - if self._speaker_objects: - conn.execute(sqlalchemy.insert(Speaker.__table__), self._speaker_objects) - if self._file_objects: - conn.execute(sqlalchemy.insert(File.__table__), self._file_objects) - if self._text_file_objects: - conn.execute(sqlalchemy.insert(TextFile.__table__), self._text_file_objects) - if self._sound_file_objects: - conn.execute(sqlalchemy.insert(SoundFile.__table__), self._sound_file_objects) - if self._speaker_ordering_objects: + if import_data.speaker_objects: + conn.execute(sqlalchemy.insert(Speaker.__table__), import_data.speaker_objects) + if import_data.file_objects: + conn.execute(sqlalchemy.insert(File.__table__), import_data.file_objects) + if import_data.text_file_objects: + conn.execute(sqlalchemy.insert(TextFile.__table__), import_data.text_file_objects) + if import_data.sound_file_objects: conn.execute( - sqlalchemy.insert(SpeakerOrdering.__table__), self._speaker_ordering_objects + sqlalchemy.insert(SoundFile.__table__), import_data.sound_file_objects ) - if self._utterance_objects: - conn.execute(sqlalchemy.insert(Utterance.__table__), self._utterance_objects) + if import_data.speaker_ordering_objects: + conn.execute( + sqlalchemy.insert(SpeakerOrdering.__table__), + import_data.speaker_ordering_objects, + ) + if import_data.utterance_objects: + conn.execute(sqlalchemy.insert(Utterance.__table__), import_data.utterance_objects) session.commit() speakers = ( session.query(Speaker.id) @@ -463,17 +460,12 @@ def _finalize_load(self, session): .group_by(Speaker.id) .having(sqlalchemy.func.count(Utterance.id) == 0) ) + self._speaker_ids = {} speaker_ids = [x[0] for x in speakers] if speaker_ids: session.query(Speaker).filter(Speaker.id.in_(speaker_ids)).delete() session.commit() self._num_speakers = None - self._speaker_objects = [] - self._file_objects = [] - self._text_file_objects = [] - self._sound_file_objects = [] - self._speaker_ordering_objects = [] - self._utterance_objects = [] def add_speaker(self, name: str, session: Session = None): """ @@ -501,9 +493,10 @@ def add_speaker(self, name: str, session: Session = None): dictionary = session.query(Dictionary).get(self.get_dictionary_id(name)) speaker_obj = Speaker(name=name, dictionary=dictionary) session.add(speaker_obj) - self._speaker_ids[name] = speaker_obj + session.flush() + self._speaker_ids[name] = speaker_obj.id else: - self._speaker_ids[name] = speaker_obj + self._speaker_ids[name] = speaker_obj.id if close: session.commit() @@ -522,102 +515,53 @@ def add_file(self, file: FileData, session: Session = None): if session is None: session = self.session() close = True - if close: - f = File( - id=self._current_file_index, - name=file.name, - relative_path=file.relative_path, - modified=False, - ) - session.add(f) - else: - self._file_objects.append( - { - "id": self._current_file_index, - "name": file.name, - "relative_path": file.relative_path, - "modified": False, - } - ) + f = File( + id=self._current_file_index, + name=file.name, + relative_path=file.relative_path, + modified=False, + ) + session.add(f) + session.flush() for i, speaker in enumerate(file.speaker_ordering): if speaker not in self._speaker_ids: - if close: - speaker_obj = Speaker( - id=self._current_speaker_index, - name=speaker, - dictionary_id=getattr(self, "_default_dictionary_id", None), - ) - session.add(speaker_obj) - else: - self._speaker_objects.append( - { - "id": self._current_speaker_index, - "name": speaker, - "dictionary_id": getattr(self, "_default_dictionary_id", None), - } - ) + speaker_obj = Speaker( + id=self._current_speaker_index, + name=speaker, + dictionary_id=getattr(self, "_default_dictionary_id", None), + ) + session.add(speaker_obj) self._speaker_ids[speaker] = self._current_speaker_index self._current_speaker_index += 1 - if close: - so = SpeakerOrdering( - file_id=self._current_file_index, - speaker_id=self._speaker_ids[speaker], - index=i, - ) - session.add(so) - else: - self._speaker_ordering_objects.append( - { - "file_id": self._current_file_index, - "speaker_id": self._speaker_ids[speaker], - "index": i, - } - ) + so = SpeakerOrdering( + file_id=self._current_file_index, + speaker_id=self._speaker_ids[speaker], + index=i, + ) + session.add(so) if file.wav_path is not None: - if close: - sf = SoundFile( - file_id=self._current_file_index, - sound_file_path=file.wav_path, - format=file.wav_info.format, - sample_rate=file.wav_info.sample_rate, - duration=file.wav_info.duration, - num_channels=file.wav_info.num_channels, - sox_string=file.wav_info.sox_string, - ) - session.add(sf) - else: - self._sound_file_objects.append( - { - "file_id": self._current_file_index, - "sound_file_path": file.wav_path, - "format": file.wav_info.format, - "sample_rate": file.wav_info.sample_rate, - "duration": file.wav_info.duration, - "num_channels": file.wav_info.num_channels, - "sox_string": file.wav_info.sox_string, - } - ) + sf = SoundFile( + file_id=self._current_file_index, + sound_file_path=file.wav_path, + format=file.wav_info.format, + sample_rate=file.wav_info.sample_rate, + duration=file.wav_info.duration, + num_channels=file.wav_info.num_channels, + sox_string=file.wav_info.sox_string, + ) + session.add(sf) if file.text_path is not None: text_type = file.text_type if isinstance(text_type, TextFileType): text_type = file.text_type.value - if close: - tf = TextFile( - file_id=self._current_file_index, - text_file_path=file.text_path, - file_type=text_type, - ) - session.add(tf) - else: - self._text_file_objects.append( - { - "file_id": self._current_file_index, - "text_file_path": file.text_path, - "file_type": text_type, - } - ) + tf = TextFile( + file_id=self._current_file_index, + text_file_path=file.text_path, + file_type=text_type, + ) + session.add(tf) frame_shift = getattr(self, "frame_shift", None) if frame_shift is not None: frame_shift = round(frame_shift / 1000, 4) @@ -626,47 +570,118 @@ def add_file(self, file: FileData, session: Session = None): num_frames = None if frame_shift is not None: num_frames = int(duration / frame_shift) - if close: - utterance = Utterance( - begin=u.begin, - end=u.end, - duration=duration, - channel=u.channel, - oovs=u.oovs, - normalized_text=u.normalized_text, - text=u.text, - normalized_text_int=u.normalized_text_int, - num_frames=num_frames, - in_subset=False, - ignored=False, - file_id=self._current_file_index, - speaker_id=self._speaker_ids[u.speaker_name], - ) - session.add(utterance) - else: - self._utterance_objects.append( - { - "begin": u.begin, - "end": u.end, - "duration": duration, - "channel": u.channel, - "oovs": u.oovs, - "normalized_text": u.normalized_text, - "text": u.text, - "normalized_text_int": u.normalized_text_int, - "num_frames": num_frames, - "in_subset": False, - "ignored": False, - "file_id": self._current_file_index, - "speaker_id": self._speaker_ids[u.speaker_name], - } - ) + utterance = Utterance( + begin=u.begin, + end=u.end, + duration=duration, + channel=u.channel, + oovs=u.oovs, + normalized_text=u.normalized_text, + text=u.text, + normalized_text_int=u.normalized_text_int, + num_frames=num_frames, + in_subset=False, + ignored=False, + file_id=self._current_file_index, + speaker_id=self._speaker_ids[u.speaker_name], + ) + session.add(utterance) if close: session.commit() session.close() self._current_file_index += 1 + def generate_import_objects(self, file: FileData) -> DatabaseImportData: + """ + Add a file to the corpus + + Parameters + ---------- + file: :class:`~montreal_forced_aligner.corpus.classes.FileData` + File to be added + """ + data = DatabaseImportData() + data.file_objects.append( + { + "id": self._current_file_index, + "name": file.name, + "relative_path": file.relative_path, + "modified": False, + } + ) + for i, speaker in enumerate(file.speaker_ordering): + if speaker not in self._speaker_ids: + data.speaker_objects.append( + { + "id": self._current_speaker_index, + "name": speaker, + "dictionary_id": getattr(self, "_default_dictionary_id", None), + } + ) + self._speaker_ids[speaker] = self._current_speaker_index + self._current_speaker_index += 1 + + data.speaker_ordering_objects.append( + { + "file_id": self._current_file_index, + "speaker_id": self._speaker_ids[speaker], + "index": i, + } + ) + if file.wav_path is not None: + data.sound_file_objects.append( + { + "file_id": self._current_file_index, + "sound_file_path": file.wav_path, + "format": file.wav_info.format, + "sample_rate": file.wav_info.sample_rate, + "duration": file.wav_info.duration, + "num_channels": file.wav_info.num_channels, + "sox_string": file.wav_info.sox_string, + } + ) + if file.text_path is not None: + text_type = file.text_type + if isinstance(text_type, TextFileType): + text_type = file.text_type.value + + data.text_file_objects.append( + { + "file_id": self._current_file_index, + "text_file_path": file.text_path, + "file_type": text_type, + } + ) + frame_shift = getattr(self, "frame_shift", None) + if frame_shift is not None: + frame_shift = round(frame_shift / 1000, 4) + for u in file.utterances: + duration = u.end - u.begin + num_frames = None + if frame_shift is not None: + num_frames = int(duration / frame_shift) + data.utterance_objects.append( + { + "begin": u.begin, + "end": u.end, + "duration": duration, + "channel": u.channel, + "oovs": u.oovs, + "normalized_text": u.normalized_text, + "text": u.text, + "normalized_text_int": u.normalized_text_int, + "num_frames": num_frames, + "in_subset": False, + "ignored": False, + "file_id": self._current_file_index, + "speaker_id": self._speaker_ids[u.speaker_name], + } + ) + + self._current_file_index += 1 + return data + @property def data_source_identifier(self) -> str: """Corpus name""" @@ -691,7 +706,7 @@ def create_subset(self, subset: int) -> None: subsets_per_dictionary = {} utts_per_dictionary = {} subsetted = 0 - for dict_id in getattr(self, "dictionary_mapping", {}).keys(): + for dict_id in getattr(self, "dictionary_lookup", {}).values(): num_utts = ( session.query(Utterance) .join(Utterance.speaker) @@ -705,7 +720,7 @@ def create_subset(self, subset: int) -> None: remaining_subset = subset - sum(subsets_per_dictionary.values()) remaining_dicts = num_dictionaries - subsetted remaining_subset_per_dictionary = int(remaining_subset / remaining_dicts) - for dict_id in getattr(self, "dictionary_mapping", {}).keys(): + for dict_id in getattr(self, "dictionary_lookup", {}).values(): num_utts = utts_per_dictionary[dict_id] if dict_id in subsets_per_dictionary: subset_per_dictionary = subsets_per_dictionary[dict_id] @@ -811,7 +826,8 @@ def create_subset(self, subset: int) -> None: session.commit() - # Extra check to make sure the randomness didn't end up with 1 or 2 utterances for a particular job/dictionary combo + # Extra check to make sure the randomness didn't end up with 1 or 2 utterances + # for a particular job/dictionary combo subset_agg = ( session.query( Speaker.job_id, Speaker.dictionary_id, sqlalchemy.func.count(Utterance.id) diff --git a/montreal_forced_aligner/corpus/features.py b/montreal_forced_aligner/corpus/features.py index 7f57e025..1afda101 100644 --- a/montreal_forced_aligner/corpus/features.py +++ b/montreal_forced_aligner/corpus/features.py @@ -7,6 +7,8 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Union +import dataclassy + from montreal_forced_aligner.abc import KaldiFunction from montreal_forced_aligner.data import MfaArguments from montreal_forced_aligner.exceptions import KaldiProcessingError @@ -28,6 +30,8 @@ ] +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class VadArguments(MfaArguments): """Arguments for :class:`~montreal_forced_aligner.corpus.features.ComputeVadFunction`""" @@ -36,6 +40,8 @@ class VadArguments(MfaArguments): vad_options: MetaDict +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class MfccArguments(MfaArguments): """ Arguments for :class:`~montreal_forced_aligner.corpus.features.MfccFunction` @@ -48,6 +54,8 @@ class MfccArguments(MfaArguments): pitch_options: MetaDict +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class CalcFmllrArguments(MfaArguments): """Arguments for :class:`~montreal_forced_aligner.corpus.features.CalcFmllrFunction`""" @@ -61,6 +69,8 @@ class CalcFmllrArguments(MfaArguments): fmllr_options: MetaDict +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class ExtractIvectorsArguments(MfaArguments): """Arguments for :class:`~montreal_forced_aligner.corpus.features.ExtractIvectorsFunction`""" diff --git a/montreal_forced_aligner/corpus/text_corpus.py b/montreal_forced_aligner/corpus/text_corpus.py index 25b320e5..e477d897 100644 --- a/montreal_forced_aligner/corpus/text_corpus.py +++ b/montreal_forced_aligner/corpus/text_corpus.py @@ -14,6 +14,7 @@ from montreal_forced_aligner.corpus.classes import FileData from montreal_forced_aligner.corpus.helper import find_exts from montreal_forced_aligner.corpus.multiprocessing import CorpusProcessWorker +from montreal_forced_aligner.data import DatabaseImportData from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionaryMixin from montreal_forced_aligner.exceptions import TextGridParseError, TextParseError from montreal_forced_aligner.utils import Stopped @@ -58,11 +59,7 @@ def _load_corpus_from_source_mp(self) -> None: ) procs.append(p) p.start() - self._speaker_objects = [] - self._file_objects = [] - self._text_file_objects = [] - self._speaker_ordering_objects = [] - self._utterance_objects = [] + import_data = DatabaseImportData() try: file_count = 0 with tqdm.tqdm( @@ -118,7 +115,7 @@ def _load_corpus_from_source_mp(self) -> None: error_dict[error_type] = [] error_dict[error_type].append(error) else: - self.add_file(file, session) + import_data.add_objects(self.generate_import_objects(file)) self.log_debug("Waiting for workers to finish...") for p in procs: @@ -128,7 +125,7 @@ def _load_corpus_from_source_mp(self) -> None: session.rollback() raise error_dict["error"][1] - self._finalize_load(session) + self._finalize_load(session, import_data) for k in ["decode_error_files", "textgrid_read_errors"]: if hasattr(self, k): @@ -184,6 +181,7 @@ def _load_corpus_from_source(self) -> None: begin_time = time.time() self.stopped = False + import_data = DatabaseImportData() sanitize_function = getattr(self, "sanitize_function", None) with self.session() as session: for root, _, files in os.walk(self.corpus_directory, followlinks=True): @@ -211,12 +209,12 @@ def _load_corpus_from_source(self) -> None: self.speaker_characters, sanitize_function, ) - self.add_file(file, session) + import_data.add_objects(self.generate_import_objects(file)) except TextParseError as e: self.decode_error_files.append(e) except TextGridParseError as e: self.textgrid_read_errors.append(e) - self._finalize_load(session) + self._finalize_load(session, import_data) if self.decode_error_files or self.textgrid_read_errors: self.log_info( "There were some issues with files in the corpus. " diff --git a/montreal_forced_aligner/data.py b/montreal_forced_aligner/data.py index cdb7efc7..c2245ddd 100644 --- a/montreal_forced_aligner/data.py +++ b/montreal_forced_aligner/data.py @@ -16,11 +16,6 @@ from .exceptions import CtmError -if typing.TYPE_CHECKING: - from dataclasses import dataclass -else: - from dataclassy import dataclass - __all__ = [ "MfaArguments", "CtmInterval", @@ -30,11 +25,59 @@ "PhoneType", "PhoneSetType", "WordData", + "DatabaseImportData", "PronunciationProbabilityCounter", ] -@dataclass(slots=True) +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) +class DatabaseImportData: + """ + Class for storing information on importing data into the database + + Parameters + ---------- + speaker_objects: list[dict[str, Any]] + List of dictionaries with :class:`~montreal_forced_aligner.db.Speaker` properties + file_objects: list[dict[str, Any]] + List of dictionaries with :class:`~montreal_forced_aligner.db.File` properties + text_file_objects: list[dict[str, Any]] + List of dictionaries with :class:`~montreal_forced_aligner.db.TextFile` properties + sound_file_objects: list[dict[str, Any]] + List of dictionaries with :class:`~montreal_forced_aligner.db.SoundFile` properties + speaker_ordering_objects: list[dict[str, Any]] + List of dictionaries with :class:`~montreal_forced_aligner.db.SpeakerOrdering` properties + utterance_objects: list[dict[str, Any]] + List of dictionaries with :class:`~montreal_forced_aligner.db.Utterance` properties + """ + + speaker_objects: typing.List[typing.Dict[str, typing.Any]] = dataclassy.factory(list) + file_objects: typing.List[typing.Dict[str, typing.Any]] = dataclassy.factory(list) + text_file_objects: typing.List[typing.Dict[str, typing.Any]] = dataclassy.factory(list) + sound_file_objects: typing.List[typing.Dict[str, typing.Any]] = dataclassy.factory(list) + speaker_ordering_objects: typing.List[typing.Dict[str, typing.Any]] = dataclassy.factory(list) + utterance_objects: typing.List[typing.Dict[str, typing.Any]] = dataclassy.factory(list) + + def add_objects(self, other_import: DatabaseImportData) -> None: + """ + Combine objects for two importers + + Parameters + ---------- + other_import: :class:`~montreal_forced_aligner.data.DatabaseImportData` + Other object with objects to import + """ + self.speaker_objects.extend(other_import.speaker_objects) + self.file_objects.extend(other_import.file_objects) + self.text_file_objects.extend(other_import.text_file_objects) + self.sound_file_objects.extend(other_import.sound_file_objects) + self.speaker_ordering_objects.extend(other_import.speaker_ordering_objects) + self.utterance_objects.extend(other_import.utterance_objects) + + +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class MfaArguments: """ Base class for argument classes for MFA functions @@ -160,7 +203,7 @@ def regex_detect(self) -> typing.Optional[re.Pattern]: return re.compile(r"[a-z]{1,3}[12345]") elif self is PhoneSetType.IPA: return re.compile( - r"[əɚʊɡɤʁɹɔɛʉɒβɲɟʝŋʃɕʰʲɾ̃̚ː˩˨˧˦˥̪̝̟̥̂̀̄ˑ̊ᵝ̠̹̞̩̯̬̺ˀˤ̻̙̘̰̤̜̹̑̽᷈᷄᷅̌̋̏‿̆͜͡ˌˈ̣]" + r"[əɚʊɡɤʁɹɔɛʉɒβɲɟʝŋʃɕʰʲɾ̃̚ː˩˨˧˦˥̪̝̟̥̂̀̄ˑ̊ᵝ̠̹̞̩̯̬̺ˀˤ̻̙̘̰̤̜̑̽᷈᷄᷅̌̋̏‿̆͜͡ˌˈ̣]" ) return None @@ -168,7 +211,7 @@ def regex_detect(self) -> typing.Optional[re.Pattern]: def suprasegmental_phone_regex(self) -> typing.Optional[re.Pattern]: """Regex for creating base phones""" if self is PhoneSetType.IPA: - return re.compile(r"([ː̟̥̂̀̄ˑ̊ᵝ̠̹̞̩̯̬̺ˤ̻̙̘̤̜̹̑̽᷈᷄᷅̌̋̏‿̆͜͡ˌ̍ʱʰʲ̚ʼ͈ˈ̣ᵝ]+)") + return re.compile(r"([ː̟̥̂̀̄ˑ̊ᵝ̠̹̞̩̯̬̺ˤ̻̙̘̤̜̑̽᷈᷄᷅̌̋̏‿̆͜͡ˌ̍ʱʰʲ̚ʼ͈ˈ̣]+)") return None @property @@ -179,7 +222,7 @@ def base_phone_regex(self) -> typing.Optional[re.Pattern]: elif self is PhoneSetType.PINYIN: return re.compile(r"[12345]") elif self is PhoneSetType.IPA: - return re.compile(r"([ː˩˨˧˦˥̟̥̂̀̄ˑ̊ᵝ̠̹̞̩̯̬̺ˀˤ̻̙̘̤̜̹̑̽᷈᷄᷅̌̋̏‿̆͜͡ˌ̍ˈ]+)") + return re.compile(r"([ː˩˨˧˦˥̟̥̂̀̄ˑ̊ᵝ̠̹̞̩̯̬̺ˀˤ̻̙̘̤̜̑̽᷈᷄᷅̌̋̏‿̆͜͡ˌ̍ˈ]+)") return None @property @@ -1019,7 +1062,8 @@ def add_consonant_variants(consonant_set): return extra_questions -@dataclass(slots=True) +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class SoundFileInformation: """ Data class for sound file information with format, duration, number of channels, bit depth, and @@ -1051,7 +1095,8 @@ def meta(self) -> typing.Dict[str, typing.Any]: return dataclassy.asdict(self) -@dataclass(slots=True) +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class FileExtensions: """ Data class for information about the current directory @@ -1077,7 +1122,8 @@ class FileExtensions: other_audio_files: typing.Dict[str, str] -@dataclass(slots=True) +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class WordData: """ Data class for information about a word and its pronunciations @@ -1094,7 +1140,8 @@ class WordData: pronunciations: typing.Set[typing.Tuple[str, ...]] -@dataclass(slots=True) +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class PronunciationProbabilityCounter: """ Data class for count information used in pronunciation probability modeling @@ -1151,7 +1198,8 @@ def add_counts(self, other_counter: PronunciationProbabilityCounter) -> None: self.non_silence_before_counts.update(other_counter.non_silence_before_counts) -@dataclass(slots=True) +# noinspection PyUnresolvedReferences +@dataclassy.dataclass(slots=True) class CtmInterval: """ Data class for intervals derived from CTM files @@ -1191,7 +1239,8 @@ def __post_init__(self): def to_tg_interval(self) -> Interval: """ - Converts the CTMInterval to `PraatIO's Interval class `_ + Converts the CTMInterval to + `PraatIO's Interval class `_ Returns ------- diff --git a/montreal_forced_aligner/dictionary/mixins.py b/montreal_forced_aligner/dictionary/mixins.py index d5cacf81..56736e17 100644 --- a/montreal_forced_aligner/dictionary/mixins.py +++ b/montreal_forced_aligner/dictionary/mixins.py @@ -132,6 +132,13 @@ def __init__( self.clitic_cleanup_regex = re.compile( rf'[{extra}{"".join(other_clitic_markers)}]' ) + non_word_character_set = sorted(self.all_punctuation) + if "-" in self.all_punctuation: + extra = "-" + non_word_character_set = [x for x in non_word_character_set if x != "-"] + self.punctuation_regex = re.compile( + rf"^[{extra}{re.escape(''.join(non_word_character_set))}]+$" + ) def __call__(self, text) -> typing.Generator[str]: """ @@ -165,7 +172,7 @@ def __call__(self, text) -> typing.Generator[str]: for w in words: if not w: continue - if w in self.all_punctuation: + if self.punctuation_regex.match(w): continue if clitic_check and w[0] == self.base_clitic_marker == w[-1]: w = w[1:-1] diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py index 9a141723..2b79b1ac 100644 --- a/montreal_forced_aligner/dictionary/multispeaker.py +++ b/montreal_forced_aligner/dictionary/multispeaker.py @@ -138,8 +138,8 @@ class MultispeakerDictionaryMixin(TemporaryDictionaryMixin, metaclass=abc.ABCMet ---------- dictionary_model: :class:`~montreal_forced_aligner.models.DictionaryModel` Dictionary model - dictionary_lookup: dict[int, str] - Mapping of dictionary ids to names + dictionary_lookup: dict[str, int] + Mapping of dictionary names to ids """ def __init__(self, dictionary_path: str = None, **kwargs): @@ -218,10 +218,7 @@ def reversed_word_mapping(self, dictionary_id: int = 1) -> Dict[int, str]: @property def num_dictionaries(self) -> int: """Number of pronunciation dictionaries""" - if self._num_dictionaries is None: - with self.session() as session: - self._num_dictionaries = session.query(Dictionary).count() - return self._num_dictionaries + return len(self.dictionary_lookup) @property def sanitize_function(self) -> MultispeakerSanitizationFunction: @@ -295,7 +292,25 @@ def dictionary_setup(self): flags=re.IGNORECASE, ) self._speaker_ids = getattr(self, "_speaker_ids", {}) + dictionary_id_cache = {} with self.session() as session: + for speaker_id, speaker_name, dictionary_id, dict_name, path in ( + session.query( + Speaker.id, Speaker.name, Dictionary.id, Dictionary.name, Dictionary.path + ) + .join(Speaker.dictionary) + .filter(Dictionary.default == False) # noqa + ): + self._speaker_ids[speaker_name] = speaker_id + dictionary_id_cache[path] = dictionary_id + self.dictionary_lookup[dict_name] = dictionary_id + dictionary = ( + session.query(Dictionary).filter(Dictionary.default == True).first() # noqa + ) + if dictionary: + self._default_dictionary_id = dictionary.id + dictionary_id_cache[dictionary.path] = self._default_dictionary_id + self.dictionary_lookup[dictionary.name] = dictionary.id word_primary_key = 1 pronunciation_primary_key = 1 word_objs = [] @@ -304,13 +319,11 @@ def dictionary_setup(self): phone_counts = collections.Counter() graphemes = set() self._current_speaker_index = getattr(self, "_current_speaker_index", 1) - for speaker, dictionary_model in self.dictionary_model.load_dictionary_paths().items(): - dictionary = ( - session.query(Dictionary) - .filter(Dictionary.path == dictionary_model.path) - .first() - ) - if dictionary is None: + for ( + dictionary_model, + speakers, + ) in self.dictionary_model.load_dictionary_paths().values(): + if dictionary_model.path not in dictionary_id_cache: word_cache = {} pronunciation_cache = set() subsequences = set() @@ -342,7 +355,7 @@ def dictionary_setup(self): clitic_marker=clitic_marker, bracket_regex=bracket_regex.pattern, laughter_regex=laughter_regex.pattern, - default=speaker == "default", + default="default" in speakers, max_disambiguation_symbol=0, silence_word=self.silence_word, oov_word=self.oov_word, @@ -352,6 +365,7 @@ def dictionary_setup(self): ) session.add(dictionary) session.flush() + dictionary_id_cache[dictionary_model.path] = dictionary.id if dictionary.default: self._default_dictionary_id = dictionary.id self._words_mappings[dictionary.id] = {} @@ -548,19 +562,19 @@ def dictionary_setup(self): self.max_disambiguation_symbol = max( self.max_disambiguation_symbol, dictionary.max_disambiguation_symbol ) - - if speaker != "default": - if speaker not in self._speaker_ids: - speaker_objs.append( - { - "id": self._current_speaker_index, - "name": speaker, - "dictionary_id": dictionary.id, - } - ) - self._speaker_ids[speaker] = self._current_speaker_index - self._current_speaker_index += 1 - self.dictionary_lookup[dictionary.id] = dictionary.name + for speaker in speakers: + if speaker != "default": + if speaker not in self._speaker_ids: + speaker_objs.append( + { + "id": self._current_speaker_index, + "name": speaker, + "dictionary_id": dictionary.id, + } + ) + self._speaker_ids[speaker] = self._current_speaker_index + self._current_speaker_index += 1 + self.dictionary_lookup[dictionary.name] = dictionary.id session.commit() self.non_silence_phones -= self.silence_phones diff --git a/montreal_forced_aligner/exceptions.py b/montreal_forced_aligner/exceptions.py index 2f6887f4..6d6a13af 100644 --- a/montreal_forced_aligner/exceptions.py +++ b/montreal_forced_aligner/exceptions.py @@ -25,6 +25,7 @@ "MFAError", "SoxError", "G2PError", + "CtmError", "PyniniAlignmentError", "ConfigError", "LMError", diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py index b6a7ea05..195128d8 100644 --- a/montreal_forced_aligner/models.py +++ b/montreal_forced_aligner/models.py @@ -10,7 +10,7 @@ import shutil import typing from shutil import copy, copyfile, make_archive, move, rmtree, unpack_archive -from typing import TYPE_CHECKING, Collection, Dict, Optional, Union +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union import requests import yaml @@ -1013,7 +1013,7 @@ def name(self) -> str: """Name of the dictionary""" return os.path.splitext(os.path.basename(self.path))[0] - def load_dictionary_paths(self) -> Dict[str, DictionaryModel]: + def load_dictionary_paths(self) -> Dict[str, Tuple[DictionaryModel, List[str]]]: """ Load the pronunciation dictionaries @@ -1027,9 +1027,11 @@ def load_dictionary_paths(self) -> Dict[str, DictionaryModel]: with open(self.path, "r", encoding="utf8") as f: data = yaml.safe_load(f) for speaker, path in data.items(): - mapping[speaker] = DictionaryModel(path) + if path not in mapping: + mapping[path] = (DictionaryModel(path), set()) + mapping[path][1].add(speaker) else: - mapping["default"] = self + mapping[self.path] = (self, {"default"}) return mapping diff --git a/montreal_forced_aligner/utils.py b/montreal_forced_aligner/utils.py index d90e5e09..f183b3db 100644 --- a/montreal_forced_aligner/utils.py +++ b/montreal_forced_aligner/utils.py @@ -405,6 +405,10 @@ def run(self) -> None: """ Run through the arguments in the queue apply the function to them """ + from .config import BLAS_THREADS + + os.environ["OPENBLAS_NUM_THREADS"] = f"{BLAS_THREADS}" + os.environ["MKL_NUM_THREADS"] = f"{BLAS_THREADS}" try: for result in self.function.run(): self.return_q.put(result) diff --git a/tests/test_corpus.py b/tests/test_corpus.py index ec50fee8..a111c1d6 100644 --- a/tests/test_corpus.py +++ b/tests/test_corpus.py @@ -64,7 +64,7 @@ def test_add(basic_corpus_dir, generated_dir): new_utterance = UtteranceData(new_speaker, new_file_name, 0, 1, text="blah blah") assert len(corpus.get_utterances(file=new_file_name, speaker=new_speaker)) == 0 - corpus.add_file(new_file) + corpus.add_file(new_file, session) corpus.add_utterance(new_utterance, session) session.commit()