Skip to content

Commit

Permalink
cmd: added import function to samples & checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
nhamilakis committed Jul 20, 2023
1 parent a540aa7 commit 0a80ddb
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 32 deletions.
57 changes: 53 additions & 4 deletions zerospeech/cmd/checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import argparse
import sys
from pathlib import Path

from rich.padding import Padding
from rich.table import Table

from zerospeech.generics import checkpoints
from zerospeech.misc import md5sum, extract
from zerospeech.networkio import check_update_repo_index, update_repo_index
from zerospeech.out import console, error_console
from zerospeech.out import console, error_console, void_console, warning_console
from zerospeech.settings import get_settings
from .cli_lib import CMD

Expand Down Expand Up @@ -64,9 +67,55 @@ def run(self, argv: argparse.Namespace):
if check_update_repo_index():
update_repo_index()

datasets = checkpoints.CheckpointDir.load()
dataset = datasets.get(argv.name, cls=checkpoints.CheckPointItem)
dataset.pull(quiet=argv.quiet, show_progress=True, verify=not argv.skip_verification)
chkpt_dir = checkpoints.CheckpointDir.load()
chkpt = chkpt_dir.get(argv.name, cls=checkpoints.CheckPointItem)
chkpt.pull(quiet=argv.quiet, show_progress=True, verify=not argv.skip_verification)


class ImportCheckpointCMD(CMD):
""" Import checkpoints from a zip archive """
COMMAND = "import"
NAMESPACE = "checkpoints"

def init_parser(self, parser: argparse.ArgumentParser):
parser.add_argument("zip_file")
parser.add_argument('-u', '--skip-verification', action='store_true',
help='Do not check hash in repo index.')
parser.add_argument('-q', '--quiet', action='store_true',
help='Suppress download info output')

def run(self, argv: argparse.Namespace):
# update repo index if necessary
if check_update_repo_index():
update_repo_index()

chkpt_dir = checkpoints.CheckpointDir.load()
archive = Path(argv.zip_file)
std_out = console
if argv.quiet:
std_out = void_console

if not archive.is_file() and archive.suffix != '.zip':
error_console.print(f'Given archive ({archive}) does not exist or is not a valid zip archive !!!')
sys.exit(1)

if not argv.skip_verification:
with std_out.status(f'Hashing {archive.name}'):
md5hash = md5sum(archive)
item = chkpt_dir.find_by_hash(md5hash)
if item is None:
error_console.print(f'Archive {archive.name} does not correspond to a registered checkpoint archive')
sys.exit(1)
name = item.name
std_out.print(f"[green]Checkpoint {name} detected")
else:
name = archive.stem
warning_console.print(f"Importing {name} without checking, could be naming/file mismatch")

with std_out.status(f"Unzipping {name}..."):
extract(archive, chkpt_dir.root_dir / name)

std_out.print(f"[green]Checkpoint {name} installed successfully !!")


class RemoveCheckpointCMD(CMD):
Expand Down
4 changes: 2 additions & 2 deletions zerospeech/cmd/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rich.table import Table

from zerospeech.datasets import DatasetsDir, Dataset
from zerospeech.misc import md5sum, unzip
from zerospeech.misc import md5sum, extract
from zerospeech.networkio import check_update_repo_index, update_repo_index
from zerospeech.out import console, error_console, warning_console, void_console
from zerospeech.settings import get_settings
Expand Down Expand Up @@ -111,7 +111,7 @@ def run(self, argv: argparse.Namespace):

# unzip dataset
with std_out.status(f"Unzipping {name}..."):
unzip(archive, datasets_dir.root_dir / name)
extract(archive, datasets_dir.root_dir / name)

std_out.print(f"[green]Dataset {name} installed successfully !!")

Expand Down
58 changes: 54 additions & 4 deletions zerospeech/cmd/samples.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@

import argparse
import sys
from pathlib import Path

from rich.padding import Padding
from rich.table import Table

from zerospeech.generics import samples
from zerospeech.misc import md5sum, extract
from zerospeech.networkio import check_update_repo_index, update_repo_index
from zerospeech.out import console, error_console
from zerospeech.out import console, error_console, void_console, warning_console
from zerospeech.settings import get_settings
from .cli_lib import CMD

Expand Down Expand Up @@ -70,6 +72,54 @@ def run(self, argv: argparse.Namespace):
sample_itm.pull(quiet=argv.quiet, show_progress=True, verify=not argv.skip_verification)


class ImportSamples(CMD):
""" Import a sample from a zip archive """
COMMAND = "import"
NAMESPACE = "samples"

def init_parser(self, parser: argparse.ArgumentParser):
parser.add_argument("zip_file")
parser.add_argument('-u', '--skip-verification', action='store_true',
help='Do not check hash in repo index.')
parser.add_argument('-q', '--quiet', action='store_true',
help='Suppress download info output')

def run(self, argv: argparse.Namespace):
# update repo index if necessary
if check_update_repo_index():
update_repo_index()

sample_dir = samples.SamplesDir.load()
archive = Path(argv.zip_file)
std_out = console
if argv.quiet:
std_out = void_console

if not archive.is_file() and archive.suffix != '.zip':
error_console.print(f'Given archive ({archive}) does not exist or is not a valid zip archive !!!')
sys.exit(1)

if not argv.skip_verification:
with std_out.status(f'Hashing {archive.name}'):
md5hash = md5sum(archive)
item = sample_dir.find_by_hash(md5hash)
if item is None:
error_console.print(f'Archive {archive.name} does not correspond to a registered sample')
sys.exit(1)
name = item.name
std_out.print(f"[green]Sample {name} detected")

else:
name = archive.stem
warning_console.print(f"Importing {name} without checking, could be naming/file mismatch")

# unzip sample
with std_out.status(f"Unzipping {name}..."):
extract(archive, sample_dir.root_dir / name)

std_out.print(f"[green]Sample {name} installed successfully !!")


class RemoveSampleCMD(CMD):
""" Remove a sample item """
COMMAND = "rm"
Expand All @@ -83,6 +133,6 @@ def run(self, argv: argparse.Namespace):
smp = sample_dir.get(argv.name)
if smp:
smp.uninstall()
console.log("[green] Dataset uninstalled successfully !")
console.log("[green] Sample uninstalled successfully !")
else:
error_console.log(f"Failed to find dataset named :{argv.name}")
error_console.log(f"Failed to find sample named :{argv.name}")
8 changes: 4 additions & 4 deletions zerospeech/datasets/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from zerospeech.generics import (
RepoItemDir, ImportableItem, DownloadableItem, Namespace, Subset
)
from zerospeech.misc import download_extract_zip, md5sum, unzip
from zerospeech.misc import extract, download_extract_archive
from zerospeech.out import console
from zerospeech.settings import get_settings

Expand Down Expand Up @@ -83,15 +83,15 @@ def pull(self, *, verify: bool = True, quiet: bool = False, show_progress: bool
md5_hash = self.origin.md5sum

# download & extract archive
download_extract_zip(self.origin.zip_url, self.location, int(self.origin.total_size),
filename=self.name, md5sum_hash=md5_hash, quiet=quiet, show_progress=show_progress)
download_extract_archive(self.origin.zip_url, self.location, int(self.origin.total_size),
filename=self.name, md5sum_hash=md5_hash, quiet=quiet, show_progress=show_progress)
if not quiet:
console.print(f"[green]Dataset {self.name} installed successfully !!")

def import_zip(self, *, archive: Path):
""" Import dataset from an archive """
# extract archive
unzip(archive, self.location)
extract(archive, self.location)


class DatasetsDir(RepoItemDir):
Expand Down
10 changes: 5 additions & 5 deletions zerospeech/generics/checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import ClassVar, Type

from ..misc import download_extract_zip
from ..out import console
from zerospeech.misc import download_extract_archive
from zerospeech.out import console
from zerospeech.settings import get_settings
from .repository import DownloadableItem, RepoItemDir, RepositoryItemType
from ..settings import get_settings

st = get_settings()

Expand All @@ -17,8 +17,8 @@ def pull(self, *, verify: bool = True, quiet: bool = False, show_progress: bool
md5_hash = self.origin.md5sum

# download & extract archive
download_extract_zip(self.origin.zip_url, self.location, int(self.origin.total_size),
filename=self.name, md5sum_hash=md5_hash, quiet=quiet, show_progress=show_progress)
download_extract_archive(self.origin.zip_url, self.location, int(self.origin.total_size),
filename=self.name, md5sum_hash=md5_hash, quiet=quiet, show_progress=show_progress)
if not quiet:
console.print(f"[green]Checkpoint set {self.name} installed successfully !!")

Expand Down
10 changes: 5 additions & 5 deletions zerospeech/generics/samples.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import ClassVar, Type

from zerospeech.misc import download_extract_archive
from zerospeech.out import console
from zerospeech.settings import get_settings
from .repository import DownloadableItem, RepositoryItemType, RepoItemDir
from ..misc import download_extract_zip
from ..out import console
from ..settings import get_settings

st = get_settings()

Expand All @@ -17,8 +17,8 @@ def pull(self, *, verify: bool = True, quiet: bool = False, show_progress: bool
md5_hash = self.origin.md5sum

# download & extract archive
download_extract_zip(self.origin.zip_url, self.location, int(self.origin.total_size),
filename=self.name, md5sum_hash=md5_hash, quiet=quiet, show_progress=show_progress)
download_extract_archive(self.origin.zip_url, self.location, int(self.origin.total_size),
filename=self.name, md5sum_hash=md5_hash, quiet=quiet, show_progress=show_progress)
if not quiet:
console.print(f"[green]Sample {self.name} installed successfully !!")

Expand Down
42 changes: 36 additions & 6 deletions zerospeech/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import contextlib
import io
import json
import re
import sys
import tarfile
import threading
import urllib.parse
from pathlib import Path
from typing import Dict, List, Union, Optional, Protocol
from zipfile import ZipFile
Expand Down Expand Up @@ -150,20 +153,47 @@ def unzip(archive: Path, output: Path):
zipObj.extractall(output)


def untar(archive: Path, output: Path):
""" Extract a tar archive (supports gzipped format) into the output directory"""
# create folder if it does not exist
output.mkdir(exist_ok=True, parents=True)
# Open & extract
with tarfile.open(archive, 'r') as tar:
tar.extractall(path=output)


def extract(archive: Path, output: Path):
""" Extract an archive into the output directory """
if archive.suffix in ('.zip',):
unzip(archive, output)
elif archive.suffix in ('.tar', '.gz', '.tgz', '.bz2', '.tbz2', '.xz', '.txz'):
untar(archive, output)
else:
raise ValueError(f'{archive.suffix}: Unsupported archive format')


def zip_folder(archive_file: Path, location: Path):
""" Create a zip archive from a folder """
with ZipFile(archive_file, 'w') as zip_obj:
for file in filter(lambda x: x.is_file(), location.rglob("*")):
zip_obj.write(file, str(file.relative_to(location)))


def download_extract_zip(
zip_url: str, target_location: Path, size_in_bytes: int, *, filename: str = "",
def get_request_filename(response: requests.Response) -> str:
""" Get filename from response """
if "Content-Disposition" in response.headers.keys():
return re.findall("filename=(.+)", response.headers["Content-Disposition"])[0]
else:
return Path(urllib.parse.unquote(response.url)).name


def download_extract_archive(
archive_url: str, target_location: Path, size_in_bytes: int, *, filename: str = "",
md5sum_hash: str = "", quiet: bool = False, show_progress: bool = True,
):
tmp_dir = st.mkdtemp()
response = requests.get(zip_url, stream=True)
tmp_filename = tmp_dir / "download.zip"
response = requests.get(archive_url, stream=True)
tmp_filename = tmp_dir / get_request_filename(response)

if quiet:
_console = void_console
Expand Down Expand Up @@ -193,8 +223,8 @@ def download_extract_zip(
_console.print("[green]MD5sum Failed, Check with repository administrator.\nExiting...")
sys.exit(1)

with _console.status("[red]Unzipping archive..."):
unzip(tmp_filename, target_location)
with _console.status("[red]Extracting archive..."):
extract(tmp_filename, target_location)


def symlink_dir_contents(source: Path, dest: Path):
Expand Down
4 changes: 2 additions & 2 deletions zerospeech/networkio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def update_repo_index():
sys.exit(1)

with st.repository_index.open('w') as fp:
json.dump(data, fp)
json.dump(data, fp, indent=4)
console.log("RepositoryIndex has been updated successfully !!")


Expand Down Expand Up @@ -56,6 +56,6 @@ def check_update_repo_index() -> bool:
last_update_local = datetime.fromisoformat(json.load(fp).get('last_modified'))
except ValueError:
warnings.warn("Local index missing or corrupted !!!")
return False
return True

return last_update_online > last_update_local

0 comments on commit 0a80ddb

Please sign in to comment.