
[NeurIPS 2023] Addressing Negative Transfer in Diffusion Models

(Figure) Generated image from ANT-UW (DiT-L) with guidance scale 3.0

This repository contains the official PyTorch implementation of the paper "Addressing Negative Transfer in Diffusion Models" (NeurIPS 2023). For a better understanding of the paper, please visit our project page and the paper on arXiv.

This code trains DiT models with ANT on the ImageNet dataset. Our implementation is based on DiT, LibMTL, and NashMTL.

Updates

  • 2024.07.04: Our project Switch Diffusion Transformer, which utilizes ANT-UW, has been accepted to ECCV 2024!
  • 2024.05.31: Uncertainty weighting, a concept we had previously explored in our work, is utilized in EDM2. We believe that rethinking diffusion models as multi-task learners is an effective direction for improving them!
  • 2023.11.05: Initial release.
  • 2023.10.07: DTR uses ANT-UW and shows that ANT-UW outperforms P2 weighting and Min-SNR.

Install prerequisites

$ pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
$ pip install -r requirements.txt
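
Before launching training, it can help to confirm that the pinned packages actually resolved in the active environment. A minimal, optional sanity check (not part of the repository):

```python
# Optional sanity check: confirm the pinned packages are importable
# before launching training. This script is illustrative, not part of the repo.
import importlib.util

for pkg in ("torch", "torchvision", "torchaudio"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'ok' if found else 'MISSING'}")
```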

Training DiT with ANT

  • Train the DiT-S model with ANT on a single GPU:

    python train_ant_single_gpu.py \
    --data_path [Your Data Path] \
    --total_clusters [Your cluster size (Default: 5)] \
    --mtl_method [uw, nash, or pcgrad]
    
  • Train the DiT-L model with ANT-UW using multiple GPUs:

    torchrun --nproc_per_node=[number of gpus] train_uw_multi_gpu.py \
    --data_path [Your Data Path] \
    --total_clusters [Your cluster size (Default: 5)]
    
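The `uw` option corresponds to uncertainty weighting: each of the `--total_clusters` timestep clusters is treated as a task with a learnable log-variance s_k, and the combined loss is commonly written as Σ_k exp(-s_k)·L_k + s_k. A minimal, framework-agnostic sketch of that weighting formula (function and variable names are illustrative, not the repository's API):

```python
import math

def uw_total_loss(cluster_losses, log_vars):
    """Uncertainty-weighted sum of per-cluster denoising losses.

    cluster_losses: mean diffusion loss L_k for each timestep cluster
    log_vars:       learnable log-variance s_k per cluster (trained jointly)
    """
    assert len(cluster_losses) == len(log_vars)
    total = 0.0
    for loss_k, s_k in zip(cluster_losses, log_vars):
        # exp(-s_k) down-weights high-variance (noisy) clusters;
        # the +s_k regularizer keeps s_k from growing without bound.
        total += math.exp(-s_k) * loss_k + s_k
    return total

# With equal log-variances of zero, this reduces to a plain sum of losses.
print(uw_total_loss([0.9, 0.5, 0.1], [0.0, 0.0, 0.0]))  # 1.5
```

In practice the log-variances are parameters optimized alongside the network; this sketch only shows the scalar combination.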

Sampling DiT trained with ANT

torchrun --nproc_per_node=[number of gpus] sample_ddp.py \
--model_config [Your config path (config/DiT-L.yaml, config/DiT-S.yaml)] \
--ckpt [Your ckpt path] \
--sample-dir [Your sample dir] 

Interval clustering

Example code for interval clustering is provided in interval_clustering.py.
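Interval clustering partitions the denoising timesteps into k contiguous intervals, which can be solved exactly with dynamic programming. Below is a minimal sketch assuming a 1-D squared-deviation cost per interval; the cost used in interval_clustering.py may differ, and all names here are illustrative:

```python
def interval_cost(x, i, j):
    """Sum of squared deviations of x[i:j] from its mean."""
    seg = x[i:j]
    mu = sum(seg) / len(seg)
    return sum((v - mu) ** 2 for v in seg)

def interval_clustering(x, k):
    """Optimal partition of x into k contiguous intervals (dynamic programming).

    Returns a list of (start, end) index pairs covering range(len(x)).
    """
    n = len(x)
    INF = float("inf")
    # dp[m][j]: best cost of splitting x[:j] into m intervals
    dp = [[INF] * (n + 1) for _ in range(k + 1)]
    cut = [[0] * (n + 1) for _ in range(k + 1)]
    dp[0][0] = 0.0
    for m in range(1, k + 1):
        for j in range(m, n + 1):
            for i in range(m - 1, j):
                c = dp[m - 1][i] + interval_cost(x, i, j)
                if c < dp[m][j]:
                    dp[m][j], cut[m][j] = c, i
    # Backtrack the interval boundaries from the last cut
    bounds, j = [], n
    for m in range(k, 0, -1):
        bounds.append((cut[m][j], j))
        j = cut[m][j]
    return bounds[::-1]

# A per-timestep signal with two clearly separated levels splits at the jump.
print(interval_clustering([0, 0, 0, 5, 5, 5], 2))  # [(0, 3), (3, 6)]
```

The O(k·n²) dynamic program is exact because intervals are contiguous, unlike general k-means, where the assignment step is combinatorial.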

Model card ($k$=5)

All models can be downloaded from the OneDrive link.

| Model | FID | IS | Precision | Recall |
| --- | --- | --- | --- | --- |
| DiT-L + ANT-UW (Multi-GPU) | 5.695 | 186.661 | 0.811 | 0.491 |
| DiT-S + Nash | 44.65 | 33.48 | 0.4209 | 0.5272 |
| DiT-S + UW | 48.40 | 30.84 | 0.4196 | 0.5172 |
