Skip to content

homerjed/sbgm

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

68 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

sbgm

Score-Based Diffusion Models in JAX

Implementation and extension of

and

in jax and equinox.

Warning

🏗️ Note this repository is under construction, expect changes. 🏗️

Score-based diffusion models

Diffusion models are deep hierarchical models for data that use neural networks to model the reverse of a diffusion process that adds a sequence of noise perturbations to the data.

Modern cutting-edge diffusion models (see citations) express both the forward and reverse diffusion processes as a Stochastic Differential Equation (SDE).


A diagram (see citations) showing how to map data to a noise distribution (the prior) with an SDE, and reverse this SDE for generative modeling. One can also reverse the associated probability flow ODE, which yields a deterministic process that samples from the same distribution as the SDE. Both the reverse-time SDE and probability flow ODE can be obtained by estimating the score.


For any SDE of the form

$$ \text{d}\boldsymbol{x} = f(\boldsymbol{x}, t)\text{d}t + g(t)\text{d}\boldsymbol{w}, $$

the reverse of the SDE from noise to data is given by

$$ \text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t + g(t)\text{d}\boldsymbol{w}. $$

For every SDE there exists an associated ordinary differential equation (ODE)

$$ \text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t)\text{d}t - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t, $$

where the trajectories of the SDE and ODE have the same marginal PDFs $p_t(\boldsymbol{x})$.

The Stein score of the marginal probability distributions over $t$ is approximated with a neural network $\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})\approx s_{\theta}(\boldsymbol{x}(t), t)$. The parameters of the neural network are fit by minimising the score-matching loss.

Computing log-likelihoods with diffusion models

For each SDE there exists a deterministic ODE with marginal likelihoods $p_t(\boldsymbol{x})$ that match the SDE for all time $t$

$$ \text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t)\text{d}t - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t = F(\boldsymbol{x}(t), t). $$

The continuous normalizing flow formalism allows the ODE to be expressed as

$$ \frac{\partial}{\partial t} \log p(\boldsymbol{x}(t)) = -\text{Tr}\bigg [ \frac{\partial}{\partial \boldsymbol{x}(t)} F(\boldsymbol{x}(t), t) \bigg ], $$

but note that maximum-likelihood training is prohibitively expensive for SDE based diffusion models.

Usage

Install via

pip install sbgm

to run

python main.py

See examples.

To run cifar10, try something like

import sbgm
import data
import configs

datasets_path = "."
root_dir = "."

config = configs.cifar10_config()

key = jr.key(config.seed)
data_key, model_key, train_key = jr.split(key, 3)

dataset = data.cifar10(datasets_path, data_key)

sharding = sbgm.shard.get_sharding()
    
# Diffusion model 
model = sbgm.models.get_model(
    model_key, 
    config.model.model_type, 
    dataset.data_shape, 
    dataset.context_shape, 
    dataset.parameter_dim,
    config
)

# Stochastic differential equation (SDE)
sde = sbgm.sde.get_sde(config.sde)

# Fit model to dataset
model = sbgm.train.train(
    train_key,
    model,
    sde,
    dataset,
    config,
    reload_opt_state=False,
    sharding=sharding,
    save_dir=root_dir
)

Features

  • Parallelised exact and approximate log-likelihood calculations,
  • UNet and transformer score network implementations,
  • VP, SubVP and VE SDEs (neural network $\beta(t)$ and $\sigma(t)$ functions are on the list!),
  • Multi-modal conditioning (basically just optional parameter and image conditioning methods),
  • Checkpointing optimiser and model,
  • Multi-device training and sampling.

Samples

Note

I haven't optimised any training/architecture hyperparameters or trained long enough here, you could do a lot better.

Flowers

Euler-Marayama sampling Flowers Euler-Marayama sampling

ODE sampling Flowers ODE sampling

CIFAR10

Euler-Marayama sampling CIFAR10 Euler-marayama sampling

ODE sampling CIFAR10 ODE sampling

SDEs

alt text

Citations

@misc{song2021scorebasedgenerativemodelingstochastic,
      title={Score-Based Generative Modeling through Stochastic Differential Equations}, 
      author={Yang Song and Jascha Sohl-Dickstein and Diederik P. Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
      year={2021},
      eprint={2011.13456},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2011.13456}, 
}
@misc{song2021maximumlikelihoodtrainingscorebased,
      title={Maximum Likelihood Training of Score-Based Diffusion Models}, 
      author={Yang Song and Conor Durkan and Iain Murray and Stefano Ermon},
      year={2021},
      eprint={2101.09258},
      archivePrefix={arXiv},
      primaryClass={stat.ML},
      url={https://arxiv.org/abs/2101.09258}, 
}