By Subham Sekhar Sahoo, Marianne Arriola, Yair Schiff, Aaron Gokaslan, Edgar Marroquin, Justin T Chiu, Alexander Rush, Volodymyr Kuleshov
This is an experimental fork of the main MDLM repo: the code may be broken and is being actively hacked on as a personal experiment. Please use the official repo for anything serious.
The codebase is organized as follows:
main.py: Routines for training and evaluation
noise_schedule.py: Noise schedules
diffusion.py: Forward/reverse diffusion (toy sketch after this list)
dataloader.py: Dataloaders
utils.py: LR scheduler, logging, and fsspec handling
models/: Denoising network architectures; supports DiT, AR transformer, and Mamba
configs/: Config files for datasets, denoising networks, noise schedules, and LR schedules
scripts/: Shell scripts for training and evaluation
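For orientation, the forward process these files implement is conceptually simple. Here is a toy sketch of masked (absorbing-state) forward diffusion under a linear schedule; the function name and the mask id are illustrative, not the repo's actual code:

import torch

def forward_mask(x0, t, mask_id):
    # Toy forward process: each token of x0 is independently replaced by
    # mask_id with probability t, i.e. a linear schedule alpha_t = 1 - t.
    # Illustrative only; see noise_schedule.py / diffusion.py for the real code.
    keep = torch.rand(x0.shape, device=x0.device) >= t
    return torch.where(keep, x0, torch.full_like(x0, mask_id))

# Example: mask roughly half the tokens of a length-8 sequence.
x0 = torch.randint(0, 50257, (1, 8))
xt = forward_mask(x0, t=0.5, mask_id=50257)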
To get started, create a conda environment containing the required dependencies.
conda env create -f requirements.yaml
conda activate mdlm
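After activating the environment, a quick sanity check that the core dependency resolved (assuming torch is pinned in requirements.yaml, which I haven't verified):
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"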
Create the following directories to store saved models and slurm logs:
mkdir outputs
mkdir watch_folder
and run the training as a batch job:
sbatch scripts/train_owt_mdlm.sh
We have uploaded an MDLM model trained on OpenWebText for 1M training steps to the Huggingface hub 🤗: kuleshov-group/mdlm-owt. We have also released checkpoints for the AR and SEDD baselines trained on OpenWebText in this Google Drive folder.
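To load the hub checkpoint directly in Python, something like the following should work. This is a minimal sketch: trust_remote_code=True is needed because the checkpoint ships custom modeling code, and that code defines the forward interface, so check the model card if anything here doesn't match.

from transformers import AutoModelForMaskedLM

# Minimal sketch: pull the MDLM-OWT weights from the Huggingface hub.
# trust_remote_code=True lets transformers run the custom modeling code
# shipped with the checkpoint; its model card documents the forward interface.
model = AutoModelForMaskedLM.from_pretrained(
    "kuleshov-group/mdlm-owt", trust_remote_code=True
)
model.eval()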
Below, we describe the steps required to reproduce the experiments in the paper. Throughout, the main entry point for running experiments is the main.py script. We also provide sample slurm scripts for launching pre-training and downstream fine-tuning experiments in the scripts/ directory.
The argument to sampling.predictor specifies the sampler, which takes one of the following values:
ddpm_cache: our proposed sampler, which is ~3-4x faster than the samplers proposed in D3PM and SEDD (see the toy sketch after this list).
ddpm: ancestral sampling, as proposed in D3PM.
analytic: the analytic sampler proposed in SEDD.
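To make the caching trick concrete: in masked diffusion, unmasked tokens never change, so whenever a reverse step unmasks nothing, the denoiser sees an identical input and its output can be reused. A toy sketch of this idea (illustrative names, not the repo's API), assuming a linear schedule in which a still-masked token is unmasked with probability 1/t at step t:

import torch

def sample_with_cache(denoiser, x, num_steps, mask_id):
    # Toy cached ancestral sampler. `denoiser` maps token ids (B, L) to
    # per-position logits (B, L, V). Since x only changes when a position is
    # unmasked, steps that unmask nothing reuse the cached logits; skipping
    # those network calls is where the speedup comes from.
    cached_logits = None
    for t in range(num_steps, 0, -1):
        if cached_logits is None:
            cached_logits = denoiser(x)  # one network call
        # Each still-masked position is unmasked with probability 1/t.
        unmask = (x == mask_id) & (torch.rand(x.shape, device=x.device) < 1.0 / t)
        if unmask.any():
            probs = torch.softmax(cached_logits, dim=-1)
            sampled = torch.distributions.Categorical(probs=probs).sample()
            x = torch.where(unmask, sampled, x)
            cached_logits = None  # input changed: recompute next step
    return x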
To generate samples from a pre-trained model, use one of the following commands. To sample from the Huggingface checkpoint:
python main.py \
mode=sample_eval \
eval.checkpoint_path=kuleshov-group/mdlm-owt \
data=openwebtext-split \
model.length=1024 \
sampling.predictor=ddpm_cache \
sampling.steps=1000 \
loader.eval_batch_size=1 \
sampling.num_sample_batches=10 \
backbone=hf_dit
To sample from a local checkpoint instead:
python main.py \
mode=sample_eval \
eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
data=openwebtext-split \
model.length=1024 \
sampling.predictor=ddpm_cache \
sampling.steps=10000 \
loader.eval_batch_size=1 \
sampling.num_sample_batches=1 \
backbone=dit
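A note on the step counts above: with ddpm_cache, reverse steps that unmask no token reuse the cached network output (see the sketch above), so the number of actual network calls is bounded by the number of unmasking events rather than by sampling.steps. Raising sampling.steps from 1,000 to 10,000 should therefore cost far less than 10x the wall-clock time.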
Please cite our work using the following BibTeX entry:
@misc{sahoo2024simple,
title={Simple and Effective Masked Diffusion Language Models},
author={Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
year={2024},
eprint={2406.07524},
archivePrefix={arXiv},
primaryClass={cs.CL}
}