Skip to content

Commit

Permalink
Fix for some pathlib issues (MontrealCorpusTools#564)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcauliffe committed Feb 14, 2023
1 parent 703167e commit 5304068
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 17 deletions.
12 changes: 6 additions & 6 deletions montreal_forced_aligner/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,9 @@ def setup(self) -> None:
self.initialize_database()

@property
def working_directory(self) -> str:
def working_directory(self) -> Path:
"""Alias for a folder that contains worker information, separate from the data directory"""
return os.path.join(self.output_directory, self._current_workflow)
return self.output_directory.joinpath(self._current_workflow)

@classmethod
def parse_args(
Expand Down Expand Up @@ -712,14 +712,14 @@ def identifier(self) -> str:
return self.data_source_identifier

@property
def output_directory(self) -> str:
def output_directory(self) -> Path:
"""Root temporary directory to store all of this worker's files"""
return os.path.join(GLOBAL_CONFIG.temporary_directory, self.identifier)
return GLOBAL_CONFIG.current_profile.temporary_directory.joinpath(self.identifier)

@property
def log_file(self) -> str:
def log_file(self) -> Path:
"""Path to the worker's log file"""
return os.path.join(self.output_directory, f"{self.data_source_identifier}.log")
return self.output_directory.joinpath(f"{self.data_source_identifier}.log")

def setup_logger(self) -> None:
"""
Expand Down
22 changes: 13 additions & 9 deletions montreal_forced_aligner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from __future__ import annotations

import os
import pathlib
import re
import typing
from pathlib import Path
from typing import Any, Dict, List, Union

import click
Expand Down Expand Up @@ -38,7 +38,7 @@
PLDA_DIMENSION = 192


def get_temporary_directory() -> Path:
def get_temporary_directory() -> pathlib.Path:
"""
Get the root temporary directory for MFA
Expand All @@ -51,15 +51,17 @@ def get_temporary_directory() -> Path:
------
:class:`~montreal_forced_aligner.exceptions.RootDirectoryError`
"""
TEMP_DIR = os.environ.get(MFA_ROOT_ENVIRONMENT_VARIABLE, os.path.expanduser("~/Documents/MFA"))
TEMP_DIR = pathlib.Path(
os.environ.get(MFA_ROOT_ENVIRONMENT_VARIABLE, os.path.expanduser("~/Documents/MFA"))
)
try:
os.makedirs(TEMP_DIR, exist_ok=True)
TEMP_DIR.mkdir(parents=True, exist_ok=True)
except OSError:
raise RootDirectoryError(TEMP_DIR, MFA_ROOT_ENVIRONMENT_VARIABLE)
return Path(TEMP_DIR)
return TEMP_DIR


def generate_config_path() -> Path:
def generate_config_path() -> pathlib.Path:
"""
Generate the global configuration path for MFA
Expand All @@ -71,7 +73,7 @@ def generate_config_path() -> Path:
return get_temporary_directory().joinpath("global_config.yaml")


def generate_command_history_path() -> Path:
def generate_command_history_path() -> pathlib.Path:
"""
Generate the path to the command history file
Expand Down Expand Up @@ -147,7 +149,7 @@ class MfaProfile:
blas_num_threads: int = 1
use_mp: bool = True
single_speaker: bool = False
temporary_directory: str = get_temporary_directory()
temporary_directory: pathlib.Path = get_temporary_directory()
github_token: typing.Optional[str] = None

def __getitem__(self, item):
Expand All @@ -166,6 +168,8 @@ def update(self, data: Union[Dict[str, Any], click.Context]) -> None:
for k, v in data.items():
if k == "temp_directory":
k = "temporary_directory"
if k == "temporary_directory":
v = pathlib.Path(v)
if v is None:
continue
if hasattr(self, k):
Expand Down Expand Up @@ -220,7 +224,7 @@ def save(self) -> None:
def load(self) -> None:
"""Load MFA configuration"""
with mfa_open(self.config_path, "r") as f:
data = yaml.safe_load(f)
data = yaml.load(f, Loader=yaml.Loader)
for name, p in data.pop("profiles", {}).items():
self.profiles[name] = MfaProfile()
self.profiles[name].update(p)
Expand Down
2 changes: 1 addition & 1 deletion montreal_forced_aligner/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def parse_old_features(config: MetaDict) -> MetaDict:
return config


def configure_logger(identifier: str, log_file: Optional[str] = None) -> None:
def configure_logger(identifier: str, log_file: Optional[Path] = None) -> None:
"""
Configure logging for the given identifier
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import pathlib
import shutil

import mock
Expand Down Expand Up @@ -85,7 +86,7 @@ def global_config():
@pytest.fixture(scope="session")
def temp_dir(generated_dir, global_config):
temp_dir = os.path.join(generated_dir, "temp")
global_config.current_profile.temporary_directory = temp_dir
global_config.current_profile.temporary_directory = pathlib.Path(temp_dir)
global_config.save()
yield temp_dir

Expand Down
12 changes: 12 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pathlib

import pytest

Expand Down Expand Up @@ -168,3 +169,14 @@ def test_load(basic_corpus_dir, basic_dict_path, temp_dir, config_directory):
with pytest.raises(ConfigError):
params = TrainableAligner.parse_parameters(path)
am_trainer.cleanup()


def test_config(global_config):
new_temp_path = global_config.current_profile.temporary_directory.joinpath("test")
global_config.current_profile.temporary_directory = new_temp_path
global_config.save()
global_config.load()
assert isinstance(global_config.current_profile.temporary_directory, pathlib.Path)
assert global_config.current_profile.temporary_directory == new_temp_path
global_config.current_profile.temporary_directory = new_temp_path.parent
global_config.save()

0 comments on commit 5304068

Please sign in to comment.