Skip to content

Commit

Permalink
support multiple gpus while sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
prafullasd committed Jun 1, 2020
1 parent f326b52 commit c6617dd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ python jukebox/sample.py --model=1b_lyrics --name=sample_1b --levels=3 --sample_
--total_sample_length_in_seconds=180 --sr=44100 --n_samples=16 --hop_fraction=0.5,0.5,0.125
```
The above generates the first `sample_length_in_seconds` seconds of audio from a song of total length `total_sample_length_in_seconds`.
To use multiple GPU's, launch the above scripts as `mpiexec -n {ngpus} python jukebox/sample.py ...` so they use `{ngpus}`

The samples decoded from each level are stored in `{name}/level_{level}`.
You can also view the samples as an html with the aligned lyrics under `{name}/level_{level}/index.html`.
Expand Down
7 changes: 5 additions & 2 deletions jukebox/sample.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import numpy as np
import torch as t
import jukebox.utils.dist_adapter as dist

from jukebox.hparams import Hyperparams
from jukebox.utils.torch_utils import empty_cache
Expand Down Expand Up @@ -106,7 +106,10 @@ def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
# Decode sample
x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

logdir = f"{hps.name}/level_{level}"
if dist.get_world_size() > 1:
logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
else:
logdir = f"{hps.name}/level_{level}"
if not os.path.exists(logdir):
os.makedirs(logdir)
t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
Expand Down

0 comments on commit c6617dd

Please sign in to comment.