Skip to content

Commit

Permalink
MAGNeT v1 (facebookresearch#387)
Browse files Browse the repository at this point in the history
* MAGNeT v1 release
* Version bump + typos
  • Loading branch information
lonzi authored Jan 15, 2024
1 parent 80540a1 commit 905371a
Show file tree
Hide file tree
Showing 28 changed files with 2,573 additions and 321 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [1.3.0a] - TBD

Adding the MAGNeT model (https://arxiv.org/abs/2401.04577) along with hf checkpoints and a gradio demo app.

Typo fixes.

## [1.2.0] - 2024-01-11

Adding stereo models.
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ At the moment, AudioCraft contains the training code and inference code for:
* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model.
* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
* [MAGNeT](./docs/MAGNET.md): A state-of-the-art non-autoregressive model for text-to-music and text-to-sound.

## Training code

Expand Down
2 changes: 1 addition & 1 deletion audiocraft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@
# flake8: noqa
from . import data, modules, models

__version__ = '1.2.0'
__version__ = '1.3.0a'
6 changes: 6 additions & 0 deletions audiocraft/grids/magnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""MAGNeT grids."""
32 changes: 32 additions & 0 deletions audiocraft/grids/magnet/audio_magnet_16khz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from ..musicgen._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
    """Declare the audio-MAGNeT 16 kHz training grid.

    Binds the `magnet/audio_magnet_16khz` solver and the sound dataset on
    the launcher, then schedules two job arrays: a default-scale run
    labelled '32gpus' and a medium-scale run with FSDP options labelled
    '64gpus'.

    Args:
        launcher: grid launcher handed in by the `LMExplorer` decorator.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=slurm_partitions)
    launcher.bind_(solver='magnet/audio_magnet_16khz')
    # replace this by the desired environmental sound dataset
    launcher.bind_(dset='internal/sounds_16khz')

    # Overrides only used by the medium run below.
    fsdp_opts = {'autocast': False, 'fsdp.use': True}
    medium_scale = {'model/lm/model_scale': 'medium'}

    # Small model (300M)
    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        small_run = launcher.bind()
        small_run()

    # Medium model (1.5B)
    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        medium_run = launcher.bind()
        medium_run(medium_scale, fsdp_opts)
74 changes: 74 additions & 0 deletions audiocraft/grids/magnet/audio_magnet_pretrained_16khz_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Evaluation with objective metrics for the pretrained audio-MAGNeT models.
This grid takes the signatures from the training grid and runs the evaluation-only stage.
When running the grid for the first time, please use:
REGEN=1 dora grid magnet.audio_magnet_pretrained_16khz_eval
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
Note that you need the external metrics libraries properly set up to use all
the objective metrics activated in this grid. Refer to the README for more information.
"""

import os

from ..musicgen._explorers import GenerationEvalExplorer
from ...environment import AudioCraftEnvironment
from ... import train


def eval(launcher, batch_size: int = 32):
    """Schedule one evaluation-only run with objective metrics.

    A sub-launcher is derived from `launcher` with the AudioCaps 16 kHz
    dataset, the `objective_eval` evaluation config, `execute_only` set to
    'evaluate' and the FAD options; it is then called once.

    Args:
        launcher: launcher already bound to the model to evaluate.
        batch_size: value passed as `+dataset.evaluate.batch_size`.
    """
    # binary for FAD computation: replace this path with your own path
    fad_bin_path = '/data/home/jadecopet/local/usr/opt/google-research'

    eval_launcher = launcher.bind({
        'dset': 'audio/audiocaps_16khz',
        'solver/audiogen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 32,
    })
    eval_launcher.bind_({'metrics.fad.tf.bin': fad_bin_path})

    # base objective metrics
    eval_launcher()


@GenerationEvalExplorer
def explorer(launcher):
    """Declare the evaluation grid for the pretrained audio-MAGNeT models.

    Without `REGEN` in the environment, the experiments already linked in
    this grid's folder are re-scheduled from their signatures. With `REGEN`
    set, the grid is rebuilt from the pretrained small and medium
    checkpoints.

    Args:
        launcher: grid launcher handed in by the `GenerationEvalExplorer`
            decorator.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=4, partition=slurm_partitions)

    regenerate = 'REGEN' in os.environ
    if not regenerate:
        # Grid folder is named after this module path with its first two
        # dotted components dropped.
        grid_dir = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
        with launcher.job_array():
            for entry in grid_dir.iterdir():
                # Only symlinked entries are treated as experiment signatures.
                if entry.is_symlink():
                    xp = train.main.get_xp_from_sig(entry.name)
                    launcher(xp.argv)
        return

    with launcher.job_array():
        base = launcher.bind(solver="magnet/audio_magnet_16khz")

        fsdp_opts = {'autocast': False, 'fsdp.use': True}

        # Small audio-MAGNeT model (300M)
        small = base.bind({'continue_from': '//pretrained/facebook/audio-magnet-small'})
        eval(small, batch_size=128)

        # Medium audio-MAGNeT model (1.5B)
        medium = base.bind({'continue_from': '//pretrained/facebook/audio-magnet-medium'})
        medium.bind_({'model/lm/model_scale': 'medium'})
        medium.bind_(fsdp_opts)
        eval(medium, batch_size=128)
47 changes: 47 additions & 0 deletions audiocraft/grids/magnet/magnet_32khz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from ..musicgen._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
    """Declare the MAGNeT 32 kHz music training grid.

    Schedules four training jobs in two job arrays: small models
    (label '32gpus') and medium models with FSDP (label '64gpus'), each
    in a 30-second (solver default) and a 10-second variant.

    Args:
        launcher: grid launcher handed in by the `LMExplorer` decorator.
    """
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    # Same solver name as the evaluation grid that reuses these signatures
    # (magnet_pretrained_32khz_eval binds "magnet/magnet_32khz").
    launcher.bind_(solver='magnet/magnet_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp = {'autocast': False, 'fsdp.use': True}
    medium = {'model/lm/model_scale': 'medium'}
    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
    segdur_10secs = {'dataset.segment_duration': 10,
                     'dataset.batch_size': 576,
                     'generate.lm.decoding_steps': [20, 10, 10, 10]}

    # Small models (300M)
    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        # 30 seconds
        sub = launcher.bind()
        sub()

        # 10 seconds
        sub = launcher.bind()
        sub(segdur_10secs)

    # Medium models (1.5B)
    launcher.bind_(fsdp)
    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        # 30 seconds
        sub = launcher.bind()
        sub(medium, adam)

        # 10 seconds: same medium scale and optimizer as the 30-second job,
        # only the segment-duration overrides differ. Without `medium` this
        # job would silently train a small-scale model on 64 GPUs.
        sub = launcher.bind()
        sub(medium, adam, segdur_10secs)
87 changes: 87 additions & 0 deletions audiocraft/grids/magnet/magnet_pretrained_32khz_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Evaluation with objective metrics for the pretrained MAGNeT models.
This grid takes the signatures from the training grid and runs the evaluation-only stage.
When running the grid for the first time, please use:
REGEN=1 dora grid magnet.magnet_pretrained_32khz_eval
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
Note that you need the external metrics libraries properly set up to use all
the objective metrics activated in this grid. Refer to the README for more information.
"""

import os

from ..musicgen._explorers import GenerationEvalExplorer
from ...environment import AudioCraftEnvironment
from ... import train


def eval(launcher, batch_size: int = 32):
    """Schedule one evaluation-only run with objective metrics.

    A sub-launcher is derived from `launcher` with the MusicCaps 32 kHz
    dataset, the `objective_eval` evaluation config, `execute_only` set to
    'evaluate' and the FAD options; it is then called once.

    Args:
        launcher: launcher already bound to the model to evaluate.
        batch_size: value passed as `+dataset.evaluate.batch_size`.
    """
    # binary for FAD computation: replace this path with your own path
    fad_bin_path = '/data/home/jadecopet/local/usr/opt/google-research'

    eval_launcher = launcher.bind({
        'dset': 'audio/musiccaps_32khz',
        'solver/musicgen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 16,
    })
    eval_launcher.bind_({'metrics.fad.tf.bin': fad_bin_path})

    # base objective metrics
    eval_launcher()


@GenerationEvalExplorer
def explorer(launcher):
    """Declare the evaluation grid for the pretrained MAGNeT music models.

    Without `REGEN` in the environment, the experiments already linked in
    this grid's folder are re-scheduled from their signatures. With `REGEN`
    set, the grid is rebuilt from the four pretrained checkpoints
    (small/medium, 10-second/30-second).

    Args:
        launcher: grid launcher handed in by the `GenerationEvalExplorer`
            decorator.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=4, partition=slurm_partitions)

    regenerate = 'REGEN' in os.environ
    if not regenerate:
        # Grid folder is named after this module path with its first two
        # dotted components dropped.
        grid_dir = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
        with launcher.job_array():
            for entry in grid_dir.iterdir():
                # Only symlinked entries are treated as experiment signatures.
                if entry.is_symlink():
                    xp = train.main.get_xp_from_sig(entry.name)
                    launcher(xp.argv)
        return

    with launcher.job_array():
        magnet = launcher.bind(solver="magnet/magnet_32khz")

        fsdp = {'autocast': False, 'fsdp.use': True}
        medium_scale = {'model/lm/model_scale': 'medium'}
        segdur_10secs = {'dataset.segment_duration': 10,
                         'generate.lm.decoding_steps': [20, 10, 10, 10]}

        # (checkpoint, overrides applied in order) for each pretrained model:
        # the 10-second models first, then the 30-second ones.
        pretrained_runs = [
            ('//pretrained/facebook/magnet-small-10secs', [segdur_10secs]),
            ('//pretrained/facebook/magnet-medium-10secs', [segdur_10secs, medium_scale, fsdp]),
            ('//pretrained/facebook/magnet-small-30secs', []),
            ('//pretrained/facebook/magnet-medium-30secs', [medium_scale, fsdp]),
        ]
        for checkpoint, overrides in pretrained_runs:
            model = magnet.bind({'continue_from': checkpoint})
            for override in overrides:
                model.bind_(override)
            eval(model, batch_size=128)
2 changes: 2 additions & 0 deletions audiocraft/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
HFEncodecModel, HFEncodecCompressionModel)
from .audiogen import AudioGen
from .lm import LMModel
from .lm_magnet import MagnetLMModel
from .multibanddiffusion import MultiBandDiffusion
from .musicgen import MusicGen
from .magnet import MAGNeT
from .unet import DiffusionUnet
Loading

0 comments on commit 905371a

Please sign in to comment.