Merge branch 'master' into doc_train_finetune
prafullasd authored May 14, 2020
2 parents d89f2e3 + 1e386f2 commit 534e5d3
Showing 4 changed files with 80 additions and 25 deletions.
60 changes: 41 additions & 19 deletions README.md
@@ -6,15 +6,17 @@ Code for "Jukebox: A Generative Model for Music"
[Paper](https://arxiv.org/abs/2005.00341)
[Blog](https://openai.com/blog/jukebox)
[Explorer](http://jukebox.openai.com/)
[Colab](https://colab.research.google.com/github/openai/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb)

# Install
Install the conda package manager from https://docs.conda.io/en/latest/miniconda.html

```
# Required: Sampling
conda create --name jukebox python=3.7.5
conda activate jukebox
conda install mpi4py=3.0.3 # if this fails, try: pip install mpi4py==3.0.3
conda install pytorch=1.4 torchvision=0.5 cudatoolkit=10.0 -c pytorch
git clone https://github.com/openai/jukebox.git
cd jukebox
pip install -r requirements.txt
@@ -117,7 +119,7 @@ We pass `sample_length = n_ctx * downsample_of_level` so that after downsampling
Here, `n_ctx = 8192` and `downsamples = (32, 256)`, giving `sample_lengths = (8192 * 32, 8192 * 256) = (262144, 2097152)` respectively for the bottom and top level.
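
As a quick sanity check, here is a small Python sketch of that arithmetic (the variable names are illustrative, not the ones used in the training script):
```
# Illustrative only: relate n_ctx and the per-level downsampling factors
# to the raw-audio sample_length, as described above.
n_ctx = 8192
downsamples = (32, 256)  # (bottom level, top level)
sample_lengths = [n_ctx * d for d in downsamples]
print(sample_lengths)  # [262144, 2097152]
```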

### Reuse pre-trained VQ-VAE and train top-level prior on new dataset from scratch.
#### Train without labels
Our pre-trained VQ-VAE can produce compressed codes for a wide variety of genres of music, and the pre-trained upsamplers
can upsample them back to audio that sounds very similar to the original audio.
To reuse these for a new dataset of your choice, you can retrain just the top-level
@@ -129,43 +131,63 @@ mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_prior,all_fp16,cpu_
--labels=False --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000 \
--lr_use_linear_decay --lr_decay={decay_steps_as_needed}
```

Training the `small_prior` with a batch size of 2, 4, and 8 requires 6.7 GB, 9.3 GB, and 15.8 GB of GPU memory, respectively. A few days to a week of training typically yields reasonable samples when the dataset is homogeneous (e.g. all piano pieces, songs of the same style, etc.).

#### Sample from new model
You can then run sample.py with the top-level of our models replaced by your new model. To do so:
- Add an entry `my_model=("vqvae", "upsampler_level_0", "upsampler_level_1", "small_prior")` in `MODELS` in `make_models.py`.
- Update the `small_prior` dictionary in `hparams.py` to include `restore_prior='path/to/checkpoint'`. If you
changed any hps directly in the command line script (e.g. `heads`), make sure to update them in the dictionary too so
that `make_models` restores your checkpoint correctly.
- Run sample.py as outlined in the sampling section, but now with `--model=my_model`.
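
As a rough sketch of those two edits (the exact contents of `MODELS` and of the `small_prior` entry may differ from what is shown here; treat this as illustrative only):
```
# In jukebox/make_models.py (sketch): register the new model.
MODELS = {
    # ... existing entries ...
    "my_model": ("vqvae", "upsampler_level_0", "upsampler_level_1", "small_prior"),
}

# In jukebox/hparams.py (sketch): add the checkpoint path to the existing
# small_prior entry, and mirror any hps you overrode on the command line.
small_prior = Hyperparams(
    # ... existing small_prior hps ...
    restore_prior='path/to/checkpoint',  # checkpoint saved by train.py
)
```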

#### Train with labels
To train with your own metadata for your audio files, implement `get_metadata` in `data/files_dataset.py` to return the
`artist`, `genre` and `lyrics` for a given audio file. For now, you can pass `''` for lyrics to not use any lyrics.
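
As a rough illustration of what that could look like (the exact signature of `get_metadata` and the metadata source are assumptions here; adapt it to however your metadata is stored):
```
# Illustrative sketch of a custom get_metadata in jukebox/data/files_dataset.py.
# Assumes you keep a simple {filename: {"artist": ..., "genre": ...}} mapping.
def get_metadata(self, filename, test):
    meta = self.my_metadata.get(filename, {})  # your own lookup table (assumed)
    artist = meta.get("artist", "unknown")
    genre = meta.get("genre", "unknown")
    lyrics = ""                                # pass '' to train without lyrics
    return artist, genre, lyrics
```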

For training with labels, we'll use `small_labelled_prior` in `hparams.py`, and we set `labels=True,labels_v3=True`.
We use two kinds of label information:
- Artist/Genre:
  - For each file, we return an artist_id and a list of genre_ids. The reason we have a list and not a single genre_id
    is that in v2, we split genres like `blues_rock` into a bag of words `[blues, rock]` and pass at most
    `max_bow_genre_size` of those; in v3, we consider it a single word and just set `max_bow_genre_size=1`.
  - Update the `v3_artist_ids` and `v3_genre_ids` to use ids from your new dataset.
  - In `small_labelled_prior`, set the hps `y_bins = (number_of_genres, number_of_artists)` and `max_bow_genre_size=1`.
- Timing:
  - For each chunk of audio, we return the `total_length` of the song, the `offset` at which the current audio chunk starts, and
    the `sample_length` of the audio chunk. We have three timing embeddings: total_length, our current position, and our
    current position as a fraction of the total length, and we divide the range of each of these values into `t_bins` discrete bins.
  - In `small_labelled_prior`, set the hps `min_duration` and `max_duration` to the shortest/longest duration of audio
    files you want for your dataset, and `t_bins` to the number of bins used to discretize timing information. Note that
    `min_duration * sr` needs to be at least `sample_length` to fit an audio chunk in it. (A hypothetical sketch of how these labels could be assembled follows this list.)
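
To make the two kinds of labels more concrete, here is a hypothetical, self-contained sketch of how artist/genre ids and the three discretized timing values could be assembled for one audio chunk. It is not code from the repository; the function name and arguments are made up for illustration:
```
# Hypothetical sketch: assemble label information for a single audio chunk.
def make_labels(artist_id, genre_ids, total_length, offset,
                max_bow_genre_size=1, t_bins=64,
                min_duration=60.0, max_duration=600.0, sr=44100):
    # v3 treats a genre like blues_rock as one word, so keep at most one id.
    genre_ids = genre_ids[:max_bow_genre_size]

    def to_bin(value, lo, hi):
        # Clamp to [lo, hi] and map into one of t_bins discrete bins.
        frac = min(max((value - lo) / (hi - lo), 0.0), 1.0)
        return int(frac * (t_bins - 1))

    timing_bins = (
        to_bin(total_length, min_duration * sr, max_duration * sr),  # total length
        to_bin(offset, 0.0, max_duration * sr),                      # absolute position
        to_bin(offset / total_length, 0.0, 1.0),                     # relative position
    )
    return artist_id, genre_ids, timing_bins
```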

After these modifications, to train the top-level prior with labels, run
```
mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_labelled_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_prior_labels \
--sample_length=1048576 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
--labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000 \
--labels_v3=True --y_bins=({genres},{artists}) --max_bow_genre_size=1 --min_duration=60.0 --max_duration=600.0 --t_bins=64
```

For sampling, follow the same instructions as [above](#sample-from-new-model) but use `small_labelled_prior` instead of `small_prior`.

#### Train with lyrics
To additionally train with lyrics, update `get_metadata` in `data/files_dataset.py` to return `lyrics` too.
For training with lyrics, we'll use `small_single_enc_dec_prior` in `hparams.py`.
- Lyrics:
  - For each file, we linearly align the lyric characters to the audio, find the position in the lyrics that corresponds to
    the midpoint of our audio chunk, and pass a window of `n_tokens` lyric characters centred around that (see the sketch after this list).
  - In `small_single_enc_dec_prior`, set the hps `use_tokens=True` and `n_tokens` to the number of lyric characters
    to use for an audio chunk. Set it according to the `sample_length` you're training on, so that it is large enough that
    the lyrics for an audio chunk are almost always found inside a window of that size.
  - If you use a non-English vocabulary, update `text_processor.py` with your new vocab and set
    `n_vocab = number of characters in vocabulary` accordingly in `small_single_enc_dec_prior`. In v2 we had `n_vocab=80`,
    and in v3 we missed `+`, so `n_vocab=79` characters.
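
A hypothetical sketch of that lyric-window selection (not the actual implementation in `jukebox`; the function and its arguments are invented for illustration):
```
# Hypothetical sketch: pick the n_tokens lyric characters centred on the
# midpoint of the current audio chunk, assuming lyrics are spread linearly
# over the whole song.
def lyric_window(lyrics, total_length, offset, sample_length, n_tokens=384):
    midpoint = offset + sample_length / 2
    center = int(len(lyrics) * midpoint / total_length)   # linear alignment
    start = max(0, min(center - n_tokens // 2, len(lyrics) - n_tokens))
    return lyrics[start:start + n_tokens]
```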

After these modifications, to train the top-level prior with labels and lyrics, run
```
mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_single_enc_dec_prior_labels \
--sample_length=786432 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
--labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000 \
--labels_v3=True --y_bins=({genres},{artists}) --max_bow_genre_size=1 --min_duration=60.0 --max_duration=600.0 --t_bins=64 \
--use_tokens=True --n_tokens=384 --n_vocab=79
```
To simplify hps choices, here we used a `single_enc_dec` model like the `1b_lyrics` model that combines both encoder and
decoder of the transformer into a single model. We do so by merging the lyric vocab and vq-vae vocab into a single
6 changes: 3 additions & 3 deletions jukebox/data/files_dataset.py
@@ -46,9 +46,9 @@ def init_dataset(self, hps):
if self.labels:
self.labeller = Labeller(hps.max_bow_genre_size, hps.n_tokens, self.sample_length, v3=hps.labels_v3)

self.t_ranges = hps.t_ranges = ((self.min_duration*self.sr, self.max_duration*self.sr), # Total length
(0.0,self.max_duration*self.sr), # Absolute pos
(0.0,1.0)) # Relative pos


def get_index_offset(self, item):
34 changes: 31 additions & 3 deletions jukebox/hparams.py
@@ -217,6 +217,18 @@ def setup_hparams(hparam_set_names, kwargs):
)
HPARAMS_REGISTRY["small_prior"] = small_prior

small_labelled_prior = Hyperparams(
labels=True,
labels_v3=True,
y_bins=(10,100), # Set this to (genres, artists) for your dataset
max_bow_genre_size=1,
min_duration=60.0,
max_duration=600.0,
t_bins=64,
)
small_labelled_prior.update(small_prior)
HPARAMS_REGISTRY["small_labelled_prior"] = small_labelled_prior

small_single_enc_dec_prior = Hyperparams(
n_ctx=6144,
prior_width=1024,
@@ -226,10 +238,18 @@ def setup_hparams(hparam_set_names, kwargs):
blocks=64,
init_scale=0.7,
c_res=1,
prime_loss_fraction=0.4,
single_enc_dec=True,
labels=True,
labels_v3=True,
y_bins=(10,100), # Set this to (genres, artists) for your dataset
max_bow_genre_size=1,
min_duration=60.0,
max_duration=600.0,
t_bins=64,
use_tokens=True,
n_tokens=384,
n_vocab=79,
)
HPARAMS_REGISTRY["small_single_enc_dec_prior"] = small_single_enc_dec_prior

@@ -249,9 +269,17 @@ def setup_hparams(hparam_set_names, kwargs):
prime_blocks=32,
prime_init_scale=0.7,
prime_c_res=1,
prime_loss_fraction=0.4,
labels=True,
labels_v3=True,
y_bins=(10,100), # Set this to (genres, artists) for your dataset
max_bow_genre_size=1,
min_duration=60.0,
max_duration=600.0,
t_bins=64,
use_tokens=True,
n_tokens=384,
n_vocab=79,
)
HPARAMS_REGISTRY["small_sep_enc_dec_prior"] = small_sep_enc_dec_prior

5 changes: 5 additions & 0 deletions jukebox/make_models.py
@@ -116,6 +116,11 @@ def make_prior(hps, vqvae, device='cuda'):
dilation_growth_rate=hps.cond_dilation_growth_rate, dilation_cycle=hps.cond_dilation_cycle,
zero_out=hps.cond_zero_out, res_scale=hps.cond_res_scale,
checkpoint_res=hps.cond_c_res) # have to keep this else names wrong
if hps.labels and hps.t_ranges == None:
print_all("Setting t_ranges from min/max duration")
hps.t_ranges = ((hps.min_duration * hps.sr, hps.max_duration * hps.sr), # Total length
(0.0, hps.max_duration * hps.sr), # Absolute pos
(0.0, 1.0)) # Relative pos
y_cond_kwargs = dict(out_width=hps.prior_width, init_scale=hps.init_scale,
y_bins=hps.y_bins, t_bins=hps.t_bins, t_ranges=hps.t_ranges,
max_bow_genre_size=hps.max_bow_genre_size)