* Rework connections again

* Rework pgvector index creation
mmcauliffe committed Feb 9, 2023
1 parent a55a686 commit f5c89bb
Showing 15 changed files with 72 additions and 81 deletions.
9 changes: 6 additions & 3 deletions montreal_forced_aligner/abc.py
@@ -68,15 +68,18 @@ def __init__(self, args: MfaArguments):
self.job_name = self.args.job_name
self.log_path = self.args.log_path

- def run(self) -> typing.Generator:
- """Run the function, calls subclassed object's ``_run`` with error handling"""
- self.db_engine = sqlalchemy.create_engine(
+ def db_engine(self):
+
+ return sqlalchemy.create_engine(
self.db_string,
poolclass=sqlalchemy.NullPool,
isolation_level="AUTOCOMMIT",
logging_name=f"{type(self).__name__}_engine",
pool_reset_on_return=None,
).execution_options(logging_token=f"{type(self).__name__}_engine")
+
+ def run(self) -> typing.Generator:
+ """Run the function, calls subclassed object's ``_run`` with error handling"""
try:
yield from self._run()
except Exception:
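Everything below follows from this hunk: db_engine is now a method rather than an attribute cached in run(), so each worker builds a throwaway, unpooled engine at the moment it opens a session. A minimal sketch of the pattern, with MyFunction and db_string as stand-ins for the real MFA function classes (illustrative only, not the exact code):

import sqlalchemy
from sqlalchemy.orm import Session

class MyFunction:  # stand-in for the MFA function classes changed below
    def __init__(self, db_string: str):
        self.db_string = db_string

    def db_engine(self):
        # A fresh engine per call: NullPool keeps no idle connections open,
        # and AUTOCOMMIT avoids long-lived transactions in worker processes.
        return sqlalchemy.create_engine(
            self.db_string,
            poolclass=sqlalchemy.NullPool,
            isolation_level="AUTOCOMMIT",
        )

    def _run(self):
        # Call sites change from Session(self.db_engine) to Session(self.db_engine()).
        with Session(self.db_engine()) as session:
            ...

Because nothing is pooled, the connection closes as soon as the with-block exits, which keeps the number of simultaneous Postgres connections tied to the workers actually doing database work.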
2 changes: 1 addition & 1 deletion montreal_forced_aligner/acoustic_modeling/monophone.py
@@ -69,7 +69,7 @@ def __init__(self, args: MonoAlignEqualArguments):
def _run(self) -> typing.Generator[typing.Tuple[int, int]]:
"""Run the function"""

- with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine) as session:
+ with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session:
job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
2 changes: 1 addition & 1 deletion montreal_forced_aligner/acoustic_modeling/trainer.py
@@ -84,7 +84,7 @@ def __init__(self, args: TransitionAccArguments):
def _run(self) -> typing.Generator[typing.Tuple[int, str]]:
"""Run the function"""

- with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine) as session:
+ with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session:
job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
12 changes: 6 additions & 6 deletions montreal_forced_aligner/alignment/multiprocessing.py
@@ -525,7 +525,7 @@ def __init__(self, args: CompileTrainGraphsArguments):
def _run(self) -> typing.Generator[typing.Tuple[int, int]]:
"""Run the function"""

- with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine) as session:
+ with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session:
job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
@@ -825,7 +825,7 @@ def __init__(self, args: AlignArguments):
def _run(self) -> typing.Generator[typing.Tuple[int, float]]:
"""Run the function"""

- with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine) as session:
+ with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
@@ -1119,7 +1119,7 @@ def setup_files(

def _run(self) -> typing.Generator[typing.Tuple[int, float]]:
"""Run the function"""
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
@@ -1446,7 +1446,7 @@ def __init__(self, args: PhoneConfidenceArguments):

def _run(self) -> typing.Generator[typing.Tuple[int, str]]:
"""Run the function"""
- with Session(self.db_engine) as session:
+ with Session(self.db_engine()) as session:
utterances = (
session.query(Utterance)
.filter(Utterance.job_id == self.job_name)
@@ -1599,7 +1599,7 @@ def _process_pronunciations(
def _run(self) -> typing.Generator[typing.Tuple[int, int, str]]:
"""Run the function"""
self.phone_symbol_table = None
- with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine) as session:
+ with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session:
job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
@@ -1987,7 +1987,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, List[CtmInterval], List[Ctm
"""Run the function"""
align_lexicon_paths = {}
self.phone_symbol_table = None
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
2 changes: 1 addition & 1 deletion montreal_forced_aligner/command_line/utils.py
@@ -217,7 +217,7 @@ def configure_pg(directory):
"#maintenance_work_mem = 64MB": "maintenance_work_mem = 500MB",
"#work_mem = 4MB": "work_mem = 128MB",
"shared_buffers = 128MB": "shared_buffers = 256MB",
"max_connections = 100": "max_connections = 300",
"max_connections = 100": "max_connections = 10000",
}
with mfa_open(os.path.join(directory, "postgresql.conf"), "r") as f:
config = f.read()
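The hunk shows only the replacement table and the read of postgresql.conf; configure_pg presumably applies it with plain string substitution and writes the file back, roughly as sketched below (the replace loop and write-back are assumptions, and plain open stands in for mfa_open):

import os

def configure_pg(directory):
    replacements = {
        "#maintenance_work_mem = 64MB": "maintenance_work_mem = 500MB",
        "#work_mem = 4MB": "work_mem = 128MB",
        "shared_buffers = 128MB": "shared_buffers = 256MB",
        "max_connections = 100": "max_connections = 10000",
    }
    conf_path = os.path.join(directory, "postgresql.conf")
    with open(conf_path, "r") as f:
        config = f.read()
    # Swap each default setting for the tuned value, then write the file back.
    for old, new in replacements.items():
        config = config.replace(old, new)
    with open(conf_path, "w") as f:
        f.write(config)

Raising max_connections from 300 to 10000 fits the NullPool-per-call engines above: many short-lived worker connections can coexist briefly, so the server-side ceiling needs headroom.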
10 changes: 5 additions & 5 deletions montreal_forced_aligner/corpus/features.py
@@ -481,7 +481,7 @@ def __init__(self, args: MfccArguments):

def _run(self) -> typing.Generator[int]:
"""Run the function"""
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = session.get(Job, self.job_name)
feats_scp_path = job.construct_path(self.data_directory, "feats", "scp")
pitch_scp_path = job.construct_path(self.data_directory, "pitch", "scp")
@@ -589,7 +589,7 @@ def __init__(self, args: FinalFeatureArguments):

def _run(self) -> typing.Generator[int]:
"""Run the function"""
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = session.get(Job, self.job_name)
feats_scp_path = job.construct_path(self.data_directory, "feats", "scp")
temp_scp_path = job.construct_path(self.data_directory, "final_features", "scp")
@@ -734,7 +734,7 @@ def __init__(self, args: PitchArguments):

def _run(self) -> typing.Generator[int]:
"""Run the function"""
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = session.get(Job, self.job_name)

feats_scp_path = job.construct_path(self.data_directory, "pitch", "scp")
@@ -802,7 +802,7 @@ def __init__(self, args: PitchRangeArguments):

def _run(self) -> typing.Generator[int]:
"""Run the function"""
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = session.get(Job, self.job_name)
wav_path = job.construct_path(self.data_directory, "wav", "scp")
segment_path = job.construct_path(self.data_directory, "segments", "scp")
@@ -1500,7 +1500,7 @@ def _run(self) -> typing.Generator[str]:
"""Run the function"""
if os.path.exists(self.ivectors_scp_path):
return
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True))
18 changes: 18 additions & 0 deletions montreal_forced_aligner/corpus/ivector_corpus.py
@@ -9,6 +9,7 @@
from typing import List

import numpy as np
+ import sqlalchemy
import tqdm

from montreal_forced_aligner.config import GLOBAL_CONFIG, IVECTOR_DIMENSION
@@ -359,6 +360,12 @@ def collect_utterance_ivectors(self) -> None:
}
pbar.update(1)
bulk_update(session, Utterance, list(update_mapping.values()))
+ session.flush()
+ session.execute(
+ sqlalchemy.text(
+ "CREATE INDEX IF NOT EXISTS utterance_ivector_index ON utterance USING ivfflat (ivector vector_cosine_ops);"
+ )
+ )
session.query(Corpus).update({Corpus.ivectors_calculated: True})
session.commit()
self._write_ivectors()
@@ -415,4 +422,15 @@ def collect_speaker_ivectors(self) -> None:
for i, speaker_id in enumerate(speaker_ids):
update_mapping[speaker_id]["plda_vector"] = ivectors[i, :]
bulk_update(session, Speaker, list(update_mapping.values()))
+ session.flush()
+ session.execute(
+ sqlalchemy.text(
+ "CREATE INDEX IF NOT EXISTS speaker_ivector_index ON speaker USING ivfflat (ivector vector_cosine_ops);"
+ )
+ )
+ session.execute(
+ sqlalchemy.text(
+ "CREATE INDEX IF NOT EXISTS speaker_plda_vector_index ON speaker USING ivfflat (plda_vector vector_cosine_ops);"
+ )
+ )
session.commit()
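Once these ivfflat indexes exist, Postgres can answer approximate nearest-neighbour queries over the vector columns by cosine distance, which is the sort of query they are there to accelerate. An illustrative lookup, not part of this commit; <=> is pgvector's cosine-distance operator, and nearest_utterances and target_ivector are hypothetical names:

import sqlalchemy
from sqlalchemy.orm import Session

def nearest_utterances(engine, target_ivector, limit=10):
    # pgvector accepts the bracketed text form, e.g. "[0.1,0.2,...]", cast to ::vector.
    target = "[" + ",".join(str(float(x)) for x in target_ivector) + "]"
    with Session(engine) as session:
        rows = session.execute(
            sqlalchemy.text(
                "SELECT id FROM utterance "
                "ORDER BY ivector <=> (:target)::vector LIMIT :limit"
            ),
            {"target": target, "limit": limit},
        )
        return [row[0] for row in rows]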
4 changes: 2 additions & 2 deletions montreal_forced_aligner/corpus/multiprocessing.py
@@ -506,7 +506,7 @@ def _no_dictionary_sanitize(self, session):
def _run(self) -> typing.Generator[typing.Tuple[int, float]]:
"""Run the function"""
self.compile_regexes()
- with Session(self.db_engine) as session:
+ with Session(self.db_engine()) as session:
dict_count = session.query(Dictionary).join(Dictionary.words).limit(1).count()
if self.use_g2p or dict_count > 0:
yield from self._dictionary_sanitize(session)
@@ -771,7 +771,7 @@ def output_to_directory(self, session) -> None:

def _run(self) -> typing.Generator[typing.Tuple[int, float]]:
"""Run the function"""
- with Session(self.db_engine) as session:
+ with Session(self.db_engine()) as session:
if self.for_features:
yield from self.output_for_features(session)
else:
45 changes: 0 additions & 45 deletions montreal_forced_aligner/db.py
@@ -880,30 +880,6 @@ class Speaker(MfaSqlBase):
utterances = relationship("Utterance", back_populates="speaker")
files = relationship("File", secondary=SpeakerOrdering, back_populates="speakers")

- __table_args__ = (
- sqlalchemy.Index(
- "speaker_ivector_index",
- "ivector",
- postgresql_using="ivfflat",
- postgresql_with={"lists": 100},
- postgresql_ops={"ivector": "vector_cosine_ops"},
- ),
- sqlalchemy.Index(
- "speaker_xvector_index",
- "xvector",
- postgresql_using="ivfflat",
- postgresql_with={"lists": 100},
- postgresql_ops={"xvector": "vector_cosine_ops"},
- ),
- sqlalchemy.Index(
- "speaker_plda_vector_index",
- "plda_vector",
- postgresql_using="ivfflat",
- postgresql_with={"lists": 100},
- postgresql_ops={"plda_vector": "vector_cosine_ops"},
- ),
- )


class File(MfaSqlBase):
"""
@@ -1335,27 +1311,6 @@ class Utterance(MfaSqlBase):
postgresql_ops={"text": "gin_trgm_ops"},
postgresql_using="gin",
),
- sqlalchemy.Index(
- "utterance_ivector_index",
- "ivector",
- postgresql_using="ivfflat",
- postgresql_with={"lists": 100},
- postgresql_ops={"ivector": "vector_cosine_ops"},
- ),
- sqlalchemy.Index(
- "utterance_xvector_index",
- "xvector",
- postgresql_using="ivfflat",
- postgresql_with={"lists": 100},
- postgresql_ops={"xvector": "vector_cosine_ops"},
- ),
- sqlalchemy.Index(
- "utterance_plda_vector_index",
- "plda_vector",
- postgresql_using="ivfflat",
- postgresql_with={"lists": 100},
- postgresql_ops={"plda_vector": "vector_cosine_ops"},
- ),
)

def __repr__(self) -> str:
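Dropping these declarative ivfflat indexes pairs with the runtime CREATE INDEX IF NOT EXISTS statements added in ivector_corpus.py and speaker_diarizer.py: ivfflat chooses its list centroids from the rows present when the index is built, so building it after the vectors are loaded, rather than on an empty table at schema creation, should give more useful clusters. A condensed sketch of the new ordering; the function name and update_rows are illustrative, and SQLAlchemy's bulk_update_mappings stands in for MFA's bulk_update helper:

import sqlalchemy
from sqlalchemy.orm import Session

from montreal_forced_aligner.db import Utterance

def load_vectors_then_index(engine, update_rows):
    with Session(engine) as session:
        # 1. write the ivectors first (each dict carries the utterance id)
        session.bulk_update_mappings(Utterance, update_rows)
        session.flush()
        # 2. only then build the approximate-search index over the populated column
        session.execute(sqlalchemy.text(
            "CREATE INDEX IF NOT EXISTS utterance_ivector_index "
            "ON utterance USING ivfflat (ivector vector_cosine_ops);"
        ))
        session.commit()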
6 changes: 3 additions & 3 deletions montreal_forced_aligner/diarization/multiprocessing.py
@@ -537,7 +537,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int, int]]:
for line in input_proc.stdout:
lines.append(line)
input_proc.wait()
- with Session(self.db_engine) as session:
+ with Session(self.db_engine()) as session:

job: Job = (
session.query(Job)
@@ -593,7 +593,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int, int]]:
else:
columns = [Utterance.id, Utterance.speaker_id, Utterance.plda_vector]
filter = Utterance.plda_vector != None # noqa
- with Session(self.db_engine) as session:
+ with Session(self.db_engine()) as session:
speakers = (
session.query(Speaker.id)
.join(Speaker.utterances)
@@ -681,7 +681,7 @@ def _run(self) -> typing.Generator[typing.Tuple[int, int, int]]:
run_opts=run_opts,
)
device = torch.device("cuda" if self.cuda else "cpu")
- with Session(self.db_engine) as session:
+ with Session(self.db_engine()) as session:

job: Job = (
session.query(Job)
23 changes: 19 additions & 4 deletions montreal_forced_aligner/diarization/speaker_diarizer.py
@@ -735,11 +735,13 @@ def initialize_mfa_clustering(self):
session.commit()
bulk_update(session, Utterance, utterance_mapping)
session.execute(
sqlalchemy.text("CREATE INDEX ix_utterance_speaker_id on utterance(speaker_id)")
sqlalchemy.text(
"CREATE INDEX IF NOT EXISTS ix_utterance_speaker_id on utterance(speaker_id)"
)
)
session.execute(
sqlalchemy.text(
- 'CREATE INDEX utterance_position_index on utterance(file_id, speaker_id, begin, "end", channel)'
+ 'CREATE INDEX IF NOT EXISTS utterance_position_index on utterance(file_id, speaker_id, begin, "end", channel)'
)
)
session.commit()
@@ -839,11 +841,13 @@ def classify_iteration(self, iteration=None) -> None:
session.commit()
bulk_update(session, Utterance, utterance_mapping)
session.execute(
sqlalchemy.text("CREATE INDEX ix_utterance_speaker_id on utterance(speaker_id)")
sqlalchemy.text(
"CREATE INDEX IF NOT EXISTS ix_utterance_speaker_id on utterance(speaker_id)"
)
)
session.execute(
sqlalchemy.text(
- 'CREATE INDEX utterance_position_index on utterance(file_id, speaker_id, begin, "end", channel)'
+ 'CREATE INDEX IF NOT EXISTS utterance_position_index on utterance(file_id, speaker_id, begin, "end", channel)'
)
)
session.commit()
@@ -1337,6 +1341,17 @@ def load_embeddings(self) -> None:
for v in update_mapping.values():
v["plda_vector"] = v["xvector"]
bulk_update(session, Utterance, list(update_mapping.values()))
+ session.flush()
+ session.execute(
+ sqlalchemy.text(
+ "CREATE INDEX IF NOT EXISTS utterance_xvector_index ON utterance USING ivfflat (xvector vector_cosine_ops);"
+ )
+ )
+ session.execute(
+ sqlalchemy.text(
+ "CREATE INDEX IF NOT EXISTS utterance_plda_vector_index ON utterance USING ivfflat (plda_vector vector_cosine_ops);"
+ )
+ )
session.query(Corpus).update({Corpus.xvectors_loaded: True})
session.commit()
logger.debug(f"Loading embeddings took {time.time() - begin:.3f} seconds")
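The remaining change in this file is the IF NOT EXISTS guard: initialize_mfa_clustering and classify_iteration can presumably run more than once against the same database, and a plain CREATE INDEX would fail the second time with a "relation already exists" error. A minimal illustration, with the engine obtained as in the db_engine() sketch above:

import sqlalchemy
from sqlalchemy.orm import Session

def ensure_speaker_id_index(engine):
    stmt = sqlalchemy.text(
        "CREATE INDEX IF NOT EXISTS ix_utterance_speaker_id ON utterance(speaker_id)"
    )
    with Session(engine) as session:
        session.execute(stmt)  # creates the index on the first run
        session.execute(stmt)  # a no-op (with a NOTICE) on any later run
        session.commit()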
8 changes: 4 additions & 4 deletions montreal_forced_aligner/ivector/multiprocessing.py
@@ -98,7 +98,7 @@ def _run(self) -> typing.Generator[None]:
"""Run the function"""
if os.path.exists(self.gselect_path):
return
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True))
@@ -171,7 +171,7 @@ def _run(self) -> typing.Generator[None]:
modified_posterior_scale = (
self.ivector_options["posterior_scale"] * self.ivector_options["subsample"]
)
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True))
@@ -248,7 +248,7 @@ def __init__(self, args: AccGlobalStatsArguments):

def _run(self) -> typing.Generator[None]:
"""Run the function"""
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True))
@@ -313,7 +313,7 @@ def __init__(self, args: AccIvectorStatsArguments):

def _run(self) -> typing.Generator[None]:
"""Run the function"""
- with Session(self.db_engine) as session, mfa_open(self.log_path, "w") as log_file:
+ with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True))