Commit: fixes

prafullasd committed May 14, 2020
1 parent aeba9eb commit 51b02fb
Showing 2 changed files with 7 additions and 5 deletions.
README.md: 3 additions & 3 deletions
@@ -136,7 +136,7 @@ To train with your own metadata for your audio files, pass `--labels=True --label
We use 2 kinds of label information:
- Artist/Genre:
- For each file, we return an artist_id and a list of genre_ids. The reason we have a list and not a single genre_id is that in `v2`, we split genres like `blues_rock` into a bag of words `[blues, rock]` and pass at most `max_bow_genre_size` of those; in `v3` we consider it a single word and just set `max_bow_genre_size=1`.
-  - Update the `v3_artist_ids` and `v3_genre_ids` to use ids from your new dataset. Pass the hps `y_bins = (number_of_artists, number_of_genres)` and `max_bow_genre_size=1`.
+  - Update the `v3_artist_ids` and `v3_genre_ids` to use ids from your new dataset. Pass the hps `y_bins = (number_of_genres, number_of_artists)` and `max_bow_genre_size=1`.
- Timing:
- For each chunk of audio, we return the `total_length` of the song, the `offset` the current audio chunk is at, and the `sample_length` of the audio chunk. We have three timing embeddings: total_length, our current position, and our current position as a fraction of the total length, and we divide the range of each of these values into `t_bins` discrete bins.
- Pass the hps `min_duration` and `max_duration` to be the shortest/longest duration of audio files, and `t_bins` for the number of bins you want to discretize timing information into (see the sketch after this list).
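For readers adapting these hps to their own dataset, here is a minimal sketch of how the genre/artist ids and the three timing bins could be assembled. This is an illustration only, not Jukebox's labelling code: `discretize` and `make_label` are invented names, durations are treated in seconds, and the field ordering simply echoes the v3 `y_bins = (number_of_genres, number_of_artists)` convention above.

```python
# Illustration only: a toy version of v3-style conditioning labels. Assumes
# y_bins is ordered (number_of_genres, number_of_artists) and timing values
# are bucketed into t_bins bins. Function names are hypothetical.

def discretize(value, v_min, v_max, bins):
    """Map value in [v_min, v_max] to an integer bin in [0, bins - 1]."""
    value = min(max(value, v_min), v_max)
    return int((value - v_min) / (v_max - v_min) * (bins - 1))

def make_label(genre_id, artist_id, total_length, offset,
               min_duration, max_duration, t_bins):
    # The three timing embeddings described above: total length, absolute
    # position of the chunk, and position as a fraction of the total length.
    total_bin = discretize(total_length, min_duration, max_duration, t_bins)
    offset_bin = discretize(offset, 0.0, total_length, t_bins)
    frac_bin = discretize(offset / total_length, 0.0, 1.0, t_bins)
    # With max_bow_genre_size=1 (v3), the genre "bag of words" is a single id.
    return [total_bin, offset_bin, frac_bin, genre_id, artist_id]

# A 240 s song with a chunk starting at 60 s, using the hps from the command below.
label = make_label(genre_id=3, artist_id=17, total_length=240.0, offset=60.0,
                   min_duration=60.0, max_duration=600.0, t_bins=64)
print(label)  # [21, 15, 15, 3, 17]
```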
@@ -146,7 +146,7 @@ After these modifications, to train a top-level with labels, run
mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_prior_labels \
--sample_length=1048576 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
--labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000 \
-   --labels_v3=True --y_bins=({artists},{genres}) --max_bow_genre_size=1 --min_duration=60.0 --max_duration=600.0 --t_bins=64
+   --labels_v3=True --y_bins=({genres},{artists}) --max_bow_genre_size=1 --min_duration=60.0 --max_duration=600.0 --t_bins=64
```

#### With lyrics
@@ -161,7 +161,7 @@ After these modifications, to train a top-level with labels and lyrics, run
mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_single_enc_dec_prior_labels \
--sample_length=786432 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
--labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000 \
-   --labels_v3=True --y_bins=({artists},{genres}) --max_bow_genre_size=1 --min_duration=60.0 --max_duration=600.0 --t_bins=64 \
+   --labels_v3=True --y_bins=({genres},{artists}) --max_bow_genre_size=1 --min_duration=60.0 --max_duration=600.0 --t_bins=64 \
--use_tokens=True --n_tokens=384 --n_vocab=79
```
To simplify hps choices, here we used a `single_enc_dec` model like the `1b_lyrics` model that combines both encoder and ...
jukebox/train.py: 4 additions & 2 deletions
@@ -33,7 +33,9 @@ def log_labels(logger, labeller, tag, y, hps):
    y = y.cpu().numpy()
    txt = f''
    for item in range(y.shape[0]):
-       txt += labeller.describe_label(y)
+       description = labeller.describe_label(y)
+       artist, genre, lyrics = description['artist'], description['genre'], description['lyrics']
+       txt += f'{item} artist:{artist}, genre:{genre}, lyrics:{lyrics}'
    logger.add_text(tag, txt)
    logger.flush()
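The new lines read as if `labeller.describe_label` returns a mapping with `artist`, `genre`, and `lyrics` entries, which is presumably why the original one-line `txt +=` was replaced. As a small stand-alone sketch of the string the patched loop builds, here is the same formatting applied to a hypothetical list of description dicts standing in for the labeller's output (a newline per item is added only for readability):

```python
# Hypothetical stand-ins for what describe_label is assumed to return per item.
descriptions = [
    {'artist': 'artist_a', 'genre': 'rock', 'lyrics': ''},
    {'artist': 'artist_b', 'genre': 'jazz', 'lyrics': 'la la la'},
]

txt = ''
for item, description in enumerate(descriptions):
    artist, genre, lyrics = description['artist'], description['genre'], description['lyrics']
    # Mirrors the format string used in the patched log_labels above.
    txt += f'{item} artist:{artist}, genre:{genre}, lyrics:{lyrics}\n'
print(txt)
```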

@@ -138,7 +140,7 @@ def sample_prior(orig_model, ema, logger, x_in, y, hps):

    # Recons
    for i in range(len(x_ds)):
-       log_aud(logger, f'sample_x_ds_start_{i}', x_ds[i], hps)
+       log_aud(logger, f'x_ds_start_{i}', x_ds[i], hps)
    orig_model.train()
    if ema is not None: ema.swap()
    logger.flush()
