kuleshov/diffusion-llm

By Subham Sekhar Sahoo, Marianne Arriola, Yair Schiff, Aaron Gokaslan, Edgar Marroquin, Justin T Chiu, Alexander Rush, Volodymyr Kuleshov

[Open In Colab] [arXiv]

(Figure: graphical abstract)

This is an experimental fork of the main MDLM repo. The code here may be broken and is being actively hacked on as a personal experiment; please use the official repo for anything serious.

Code Organization

  1. main.py: Routines for training and evaluation
  2. noise_schedule.py: Noise schedules
  3. diffusion.py: Forward/reverse diffusion
  4. dataloader.py: Dataloaders
  5. utils.py: LR scheduler, logging, fsspec handling
  6. models/: Denoising network architectures. Supports DiT, AR transformer, and Mamba
  7. configs/: Config files for datasets/denoising networks/noise schedules/LR schedules
  8. scripts/: Shell scripts for training/evaluation

Getting started in this repository

To get started, create a conda environment containing the required dependencies.

conda env create -f requirements.yaml
conda activate mdlm
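
As an optional sanity check that the environment resolved correctly, verify that PyTorch sees your GPU (this assumes a CUDA-capable machine; the command is plain PyTorch, not specific to this repo):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"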

Create the following directories to store saved models and slurm logs:

mkdir outputs
mkdir watch_folder

and run the training as a batch job:

sbatch scripts/train_owt_mdlm.sh
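
If you are not on a slurm cluster, the same run can in principle be launched directly through main.py. The sketch below is an assumption based on the Hydra-style overrides used elsewhere in this README (data, model.length, loader.eval_batch_size); the mode=train and loader.batch_size names are not verified here, so check scripts/train_owt_mdlm.sh for the exact overrides used in the paper runs:

python main.py \
  mode=train \
  data=openwebtext-split \
  model.length=1024 \
  loader.batch_size=32 \
  loader.eval_batch_size=32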

Checkpoints

We have uploaded an MDLM model trained on OpenWebText for 1M training steps to the Hugging Face hub 🤗: kuleshov-group/mdlm-owt. Furthermore, we have released the checkpoints for the AR and SEDD baselines trained on OpenWebText in this Google Drive folder.
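
If you prefer to fetch the Hugging Face checkpoint ahead of time (e.g. on a cluster where compute nodes have no internet access), a standard hub download works; this assumes a recent huggingface_hub, and the --local-dir path below is just an example:

pip install -U huggingface_hub
huggingface-cli download kuleshov-group/mdlm-owt --local-dir ./outputs/mdlm-owt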

Reproducing Experiments

Below, we describe the steps required for reproducing the experiments in the paper. Throughout, the main entry point for running experiments is the main.py script. We also provide sample slurm scripts for launching pre-training and downstream fine-tuning experiments in the scripts/ directory.

Generate Samples

The sampling.predictor argument specifies the sampler and takes one of the following values:

  • ddpm_cache: our proposed sampler, which is ~3-4x faster than the samplers proposed in D3PM and SEDD.
  • ddpm: Ancestral sampling proposed in D3PM.
  • analytic: Analytic sampler proposed in SEDD.

To generate samples from a pre-trained model use one of the following commands:

Huggingface model

python main.py \
  mode=sample_eval \
  eval.checkpoint_path=kuleshov-group/mdlm-owt \
  data=openwebtext-split  \
  model.length=1024  \
  sampling.predictor=ddpm_cache  \
  sampling.steps=1000 \
  loader.eval_batch_size=1 \
  sampling.num_sample_batches=10 \
  backbone=hf_dit

Local checkpoint

python main.py \
  mode=sample_eval \
  eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
  data=openwebtext-split  \
  model.length=1024  \
  sampling.predictor=ddpm_cache  \
  sampling.steps=10000 \
  loader.eval_batch_size=1 \
  sampling.num_sample_batches=1 \
  backbone=dit
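
The paper also reports validation perplexities. Below is a hedged sketch of such an evaluation run; mode=ppl_eval is an assumption that mirrors mode=sample_eval above and may not match this fork's configs, so check configs/ for the actual mode names:

python main.py \
  mode=ppl_eval \
  eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
  data=openwebtext-split \
  model.length=1024 \
  loader.eval_batch_size=16 \
  backbone=dit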

Citation

@misc{sahoo2024simple,
      title={Simple and Effective Masked Diffusion Language Models}, 
      author={Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
      year={2024},
      eprint={2406.07524},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
