[Feature] Support LCM #96

Merged: 2 commits, merged on Nov 22, 2023

Changes from 1 commit
Support LCM
okotaku committed Nov 22, 2023
commit 4da7bd18bacd310a0561cf894688dc4f5726ad36
71 changes: 39 additions & 32 deletions .devcontainer/devcontainer.json
@@ -1,35 +1,42 @@
{
"dockerComposeFile": ["../docker-compose.yml"],
"service": "diffengine",
"workspaceFolder": "/workspace",
"postCreateCommand": "pre-commit install",
"customizations": {
"vscode": {
"extensions": [
"njpwerner.autodocstring",
"ms-python.black-formatter",
"ms-vscode-remote.remote-containers",
"ms-azuretools.vscode-docker",
"usernamehw.errorlens",
"GitHub.copilot",
"GitHub.copilot-chat",
"GitHub.vscode-pull-request-github",
"ms-toolsai.jupyter",
"ms-toolsai.vscode-jupyter-cell-tags",
"ms-toolsai.jupyter-keymap",
"ms-toolsai.jupyter-renderers",
"ms-toolsai.vscode-jsupyter-slideshow",
"ms-python.vscode-pylance",
"ms-python.python",
"KevinRose.vsc-python-indent",
"ms-vscode-remote.remote-ssh",
"ms-vscode-remote.remote-ssh-edit",
"ms-vscode.remote-explorer",
"wayou.vscode-todo-highlight",
"Gruntfuggly.todo-tree",
"streetsidesoftware.code-spell-checker",
"charliermarsh.ruff"
]
}
"dockerComposeFile": [
"../docker-compose.yml"
],
"service": "diffengine",
"workspaceFolder": "/workspace",
"postCreateCommand": "pre-commit install",
"customizations": {
"vscode": {
"extensions": [
"njpwerner.autodocstring",
"ms-python.black-formatter",
"ms-vscode-remote.remote-containers",
"ms-azuretools.vscode-docker",
"usernamehw.errorlens",
"GitHub.copilot",
"GitHub.copilot-chat",
"GitHub.vscode-pull-request-github",
"ms-toolsai.jupyter",
"ms-toolsai.vscode-jupyter-cell-tags",
"ms-toolsai.jupyter-keymap",
"ms-toolsai.jupyter-renderers",
"ms-toolsai.vscode-jsupyter-slideshow",
"ms-python.vscode-pylance",
"ms-python.python",
"KevinRose.vsc-python-indent",
"ms-vscode-remote.remote-ssh",
"ms-vscode-remote.remote-ssh-edit",
"ms-vscode.remote-explorer",
"wayou.vscode-todo-highlight",
"Gruntfuggly.todo-tree",
"streetsidesoftware.code-spell-checker",
"charliermarsh.ruff",
"github.vscode-github-actions",
"tamasfe.even-better-toml",
"oderwat.indent-rainbow",
"yzhang.markdown-all-in-one",
"ionutvmi.path-autocomplete"
]
}
}
}
35 changes: 24 additions & 11 deletions README.md
@@ -13,17 +13,19 @@

## 📄 Table of Contents

- [📖 Introduction](#-introduction-)
- [🛠️ Installation](#-installation-)
- [👨‍🏫 Get Started](#-get-started-)
- [🖋 Example Notebook](#-example-notebook-)
- [📘 Documentation](#-documentation-)
- [📊 Model Zoo](#-model-zoo-)
- [🙌 Contributing](#-contributing-)
- [🎫 License](#-license-)
- [🖊️ Citation](#-citation-)
- [💻 Sponsors](#-sponsors-)
- [🤝 Acknowledgement](#-acknowledgement-)
- [DiffEngine](#diffengine)
- [📄 Table of Contents](#-table-of-contents)
- [📖 Introduction 🔝](#-introduction-)
- [🛠️ Installation 🔝](#️-installation-)
- [👨‍🏫 Get Started 🔝](#-get-started-)
- [🖋 Example Notebook 🔝](#-example-notebook-)
- [📘 Documentation 🔝](#-documentation-)
- [📊 Model Zoo 🔝](#-model-zoo-)
- [🙌 Contributing 🔝](#-contributing-)
- [🎫 License 🔝](#-license-)
- [🖊️ Citation 🔝](#️-citation-)
- [💻 Sponsors](#-sponsors)
- [🤝 Acknowledgement 🔝](#-acknowledgement-)

## 📖 Introduction [🔝](#-table-of-contents)

@@ -134,6 +136,8 @@ For detailed user guides and advanced guides, please refer to our [Documentation
- [Run InstructPix2Pix](https://diffengine.readthedocs.io/en/latest/run_guides/run_instruct_pix2pix.html)
- [Run Wuerstchen](https://diffengine.readthedocs.io/en/latest/run_guides/run_wuerstchen.html)
- [Run Wuerstchen LoRA](https://diffengine.readthedocs.io/en/latest/run_guides/run_wuerstchen_lora.html)
- [Run LCM XL](https://diffengine.readthedocs.io/en/latest/run_guides/run_lcm.html)
- [Run LCM XL LoRA](https://diffengine.readthedocs.io/en/latest/run_guides/run_lcm_lora.html)
- [Inference](https://diffengine.readthedocs.io/en/latest/run_guides/inference.html)

</details>
@@ -231,6 +235,9 @@ For detailed user guides and advanced guides, please refer to our [Documentation
<td>
<b>Wuerstchen</b>
</td>
<td>
<b>Latent Consistency Models</b>
</td>
</tr>
<tr valign="top">
<td>
@@ -239,6 +246,12 @@ For detailed user guides and advanced guides, please refer to our [Documentation
<li><a href="configs/wuerstchen_lora/README.md">LoRA (ICLR'2022)</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/lcm/README.md">Latent Consistency Models (2023)</a></li>
<li><a href="configs/lcm_lora/README.md">LoRA (ICLR'2022)</a></li>
</ul>
</td>
</tr>
</td>
</tr>
7 changes: 7 additions & 0 deletions configs/_base_/models/lcm_xl.py
@@ -0,0 +1,7 @@
model = dict(
    type="LatentConsistencyModelsXL",
    model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    loss=dict(type="HuberLoss"),
    pre_compute_text_embeddings=True,
    gradient_checkpointing=True)
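
As with the other base files in this PR, downstream configs inherit this file and override individual fields of the `model` dict. A minimal sketch (the file name and the swapped-in loss type are illustrative, assuming such a loss is registered in diffengine):

```py
# my_lcm_xl.py -- hypothetical user config inheriting the base model settings
_base_ = ["configs/_base_/models/lcm_xl.py"]

# Override a single field of the inherited `model` dict, e.g. the loss.
model = dict(loss=dict(type="L2Loss"))
```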
27 changes: 27 additions & 0 deletions configs/_base_/models/lcm_xl_lora.py
@@ -0,0 +1,27 @@
model = dict(
    type="LatentConsistencyModelsXL",
    model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    loss=dict(type="HuberLoss"),
    pre_compute_text_embeddings=True,
    gradient_checkpointing=True,
    unet_lora_config=dict(
        type="LoRA",
        r=8,
        lora_alpha=1,
        target_modules=[
            "to_q",
            "to_k",
            "to_v",
            "to_out.0",
            "proj_in",
            "proj_out",
            "ff.net.0.proj",
            "ff.net.2",
            "conv1",
            "conv2",
            "conv_shortcut",
            "downsamplers.0.conv",
            "upsamplers.0.conv",
            "time_emb_proj",
        ]))
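
The `unet_lora_config` dict mirrors a PEFT `LoraConfig`; a rough sketch of the equivalent PEFT object is shown below (illustrative only; diffengine constructs this internally from the registered `LoRA` type):

```py
from peft import LoraConfig

# Sketch of the PEFT configuration corresponding to the dict above.
lora_config = LoraConfig(
    r=8,           # LoRA rank
    lora_alpha=1,  # scaling factor applied to the LoRA update
    target_modules=[
        "to_q", "to_k", "to_v", "to_out.0",         # attention projections
        "proj_in", "proj_out",                      # transformer block projections
        "ff.net.0.proj", "ff.net.2",                # feed-forward layers
        "conv1", "conv2", "conv_shortcut",          # ResNet convolutions
        "downsamplers.0.conv", "upsamplers.0.conv", # up/down sampling convolutions
        "time_emb_proj",                            # timestep embedding projection
    ],
)
```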
17 changes: 17 additions & 0 deletions configs/_base_/schedules/lcm_xl_50e.py
@@ -0,0 +1,17 @@
optim_wrapper = dict(
    type="AmpOptimWrapper",
    dtype="float16",
    optimizer=dict(type="AdamW8bit", lr=1e-6, weight_decay=0.0),
    clip_grad=dict(max_norm=1.0))

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=50)
val_cfg = None
test_cfg = None

default_hooks = dict(
    checkpoint=dict(
        type="CheckpointHook",
        interval=1,
        max_keep_ckpts=3,
    ))
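
Derived configs can override any of these schedule fields; a minimal sketch (file name and values hypothetical):

```py
# my_schedule.py -- hypothetical override of the 50-epoch LCM schedule
_base_ = ["configs/_base_/schedules/lcm_xl_50e.py"]

# Shorter run with a higher learning rate; keep only the newest checkpoint.
train_cfg = dict(by_epoch=True, max_epochs=10)
optim_wrapper = dict(optimizer=dict(lr=1e-5))
default_hooks = dict(checkpoint=dict(max_keep_ckpts=1))
```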
89 changes: 89 additions & 0 deletions configs/lcm/README.md
@@ -0,0 +1,89 @@
# Latent Consistency Models

[Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378)

## Abstract

Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (Song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (Rombach et al.). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/7e531a3d-3256-461e-86e5-9e31311ac46e"/>
</div>
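
As a rough illustration of the idea, consistency distillation trains the student to map any point on the PF-ODE trajectory directly to its solution, using a one-step teacher solve and an EMA copy of the student as the target. A heavily simplified sketch (all names below are placeholders, not diffengine APIs):

```py
import torch
import torch.nn.functional as F

def lcm_distillation_loss(student, ema_student, teacher_ode_step, z_t, t, t_prev, cond):
    """One simplified consistency-distillation step (illustrative sketch only).

    teacher_ode_step: a single solver step of the pre-trained, guidance-augmented
    teacher, moving the latent from timestep t to t_prev along the PF-ODE.
    """
    # The student predicts the ODE solution (clean latent) directly from (z_t, t).
    pred = student(z_t, t, cond)
    with torch.no_grad():
        z_prev = teacher_ode_step(z_t, t, t_prev, cond)  # one teacher ODE step
        target = ema_student(z_prev, t_prev, cond)       # self-consistency target
    # Huber / smooth-L1 loss, matching the HuberLoss in the model config.
    return F.smooth_l1_loss(pred, target)
```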

## Citation

```
@misc{luo2023latent,
    title={Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference},
    author={Simian Luo and Yiqin Tan and Longbo Huang and Jian Li and Hang Zhao},
    year={2023},
    eprint={2310.04378},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

## Run Training

```bash
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example.
$ mim train diffengine configs/lcm/lcm_xl_pokemon_blip.py
```

## Inference with diffusers

Once you have trained a model, specify the path to the saved checkpoint and use it for inference with the `diffusers` pipeline.

Before running inference, convert the weights to the diffusers format:

```bash
$ mim run diffengine publish_model2diffusers ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
# Example
$ mim run diffengine publish_model2diffusers configs/lcm/lcm_xl_pokemon_blip.py work_dirs/lcm_xl_pokemon_blip/epoch_50.pth work_dirs/lcm_xl_pokemon_blip --save-keys unet
```

Then we can run inference.

```py
import torch
from diffusers import DiffusionPipeline, AutoencoderKL, LCMScheduler, UNet2DConditionModel

checkpoint = 'work_dirs/lcm_xl_pokemon_blip'
prompt = 'yoda pokemon'

unet = UNet2DConditionModel.from_pretrained(
    checkpoint, subfolder='unet', torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(
    'madebyollin/sdxl-vae-fp16-fix',
    torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    unet=unet,
    scheduler=LCMScheduler.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0', subfolder='scheduler'),
    vae=vae,
    torch_dtype=torch.float16)
pipe.to('cuda')

image = pipe(
    prompt,
    num_inference_steps=4,
    height=1024,
    width=1024,
    guidance_scale=1.0,
).images[0]
image.save('demo.png')
```
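
Note that `guidance_scale=1.0` effectively disables classifier-free guidance, since LCM distillation bakes the guidance into the model. For reproducible samples you can additionally pass a seeded generator (standard diffusers usage, not specific to this repo):

```py
# Seed the sampler for reproducible outputs.
generator = torch.Generator('cuda').manual_seed(0)
image = pipe(
    prompt,
    num_inference_steps=4,
    height=1024,
    width=1024,
    guidance_scale=1.0,
    generator=generator,
).images[0]
```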

## Results Example

#### lcm_xl_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/8fd9799d-11a3-4cd1-8f08-f91e9e7cef3c)
20 changes: 20 additions & 0 deletions configs/lcm/lcm_xl_pokemon_blip.py
@@ -0,0 +1,20 @@
_base_ = [
    "../_base_/models/lcm_xl.py",
    "../_base_/datasets/pokemon_blip_xl_pre_compute.py",
    "../_base_/schedules/lcm_xl_50e.py",
    "../_base_/default_runtime.py",
]

train_dataloader = dict(batch_size=2)

optim_wrapper_cfg = dict(accumulative_counts=2)  # accumulate over 2 steps (effective batch size 4)

custom_hooks = [
    dict(
        type="VisualizationHook",
        prompt=["yoda pokemon"] * 4,
        height=1024,
        width=1024),
    dict(type="SDCheckpointHook"),
    dict(type="LCMEMAUpdateHook"),
]
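
The `LCMEMAUpdateHook` maintains the exponential moving average of the online weights that consistency distillation uses as its target network. A minimal sketch of the update rule (illustrative; the decay value and function are placeholders, not the hook's actual code):

```py
import torch

@torch.no_grad()
def ema_update(ema_params, online_params, decay=0.95):
    # ema <- decay * ema + (1 - decay) * online, applied after each optimizer step.
    for ema_p, p in zip(ema_params, online_params):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```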
86 changes: 86 additions & 0 deletions configs/lcm_lora/README.md
@@ -0,0 +1,86 @@
# LCM-LoRA

[LCM-LoRA: A Universal Stable-Diffusion Acceleration Module](https://arxiv.org/abs/2311.05556)

## Abstract

Latent Consistency Models (LCMs) have achieved impressive performance in accelerating text-to-image generative tasks, producing high-quality images with minimal inference steps. LCMs are distilled from pre-trained latent diffusion models (LDMs), requiring only ~32 A100 GPU training hours. This report further extends LCMs' potential in two aspects: First, by applying LoRA distillation to Stable-Diffusion models including SD-V1.5, SSD-1B, and SDXL, we have expanded LCM's scope to larger models with significantly less memory consumption, achieving superior image generation quality. Second, we identify the LoRA parameters obtained through LCM distillation as a universal Stable-Diffusion acceleration module, named LCM-LoRA. LCM-LoRA can be directly plugged into various Stable-Diffusion fine-tuned models or LoRAs without training, thus representing a universally applicable accelerator for diverse image generation tasks. Compared with previous numerical PF-ODE solvers such as DDIM, DPM-Solver, LCM-LoRA can be viewed as a plug-in neural PF-ODE solver that possesses strong generalization abilities.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/d2d6be56-824f-4623-8a06-59726fb0e6b1"/>
</div>
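
In practice, this plug-and-play property means the distilled LoRA can be attached to an off-the-shelf SDXL pipeline without any training. A sketch using the officially published LCM-LoRA weights (assumes a recent diffusers version with `load_lora_weights` support):

```py
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)
# Swap in the LCM scheduler and attach the published acceleration LoRA.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights('latent-consistency/lcm-lora-sdxl')
pipe.to('cuda')

image = pipe('yoda pokemon', num_inference_steps=4, guidance_scale=1.0).images[0]
```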

## Citation

```
@article{luo2023lcm,
    title={LCM-LoRA: A Universal Stable-Diffusion Acceleration Module},
    author={Luo, Simian and Tan, Yiqin and Patil, Suraj and Gu, Daniel and von Platen, Patrick and Passos, Apolin{\'a}rio and Huang, Longbo and Li, Jian and Zhao, Hang},
    journal={arXiv preprint arXiv:2311.05556},
    year={2023}
}
```

## Run Training

```bash
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example.
$ mim train diffengine configs/lcm_lora/lcm_xl_lora_pokemon_blip.py
```

## Inference with diffusers

Once you have trained a model, specify the path to the saved checkpoint and use it for inference with the `diffusers` pipeline.

```py
from pathlib import Path

import torch
from diffusers import DiffusionPipeline, AutoencoderKL, LCMScheduler
from peft import PeftModel

checkpoint = Path('work_dirs/lcm_xl_lora_pokemon_blip/step20850')
prompt = 'yoda pokemon'

vae = AutoencoderKL.from_pretrained(
    'madebyollin/sdxl-vae-fp16-fix',
    torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    vae=vae,
    scheduler=LCMScheduler.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0', subfolder='scheduler'),
    torch_dtype=torch.float16)
pipe.to('cuda')
pipe.unet = PeftModel.from_pretrained(pipe.unet, checkpoint / "unet", adapter_name="default")
# The SDXL pipeline exposes its two text encoders as `text_encoder` and
# `text_encoder_2`; load each saved adapter onto the matching encoder.
if (checkpoint / "text_encoder_one").exists():
    pipe.text_encoder = PeftModel.from_pretrained(
        pipe.text_encoder, checkpoint / "text_encoder_one", adapter_name="default"
    )
if (checkpoint / "text_encoder_two").exists():
    pipe.text_encoder_2 = PeftModel.from_pretrained(
        pipe.text_encoder_2, checkpoint / "text_encoder_two", adapter_name="default"
    )

image = pipe(
    prompt,
    num_inference_steps=4,
    guidance_scale=1.0,
    height=1024,
    width=1024,
).images[0]
image.save('demo.png')
```
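
If you want to avoid the PEFT wrapper overhead at inference time, the adapter can be merged into the base weights first; a sketch using PEFT's standard `merge_and_unload`:

```py
# Merge the LoRA weights into the base UNet and drop the adapter wrapper.
pipe.unet = pipe.unet.merge_and_unload()
```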

## Results Example

#### lcm_xl_lora_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/c321c36e-ba99-42f7-ab0f-4f790253926f)