Skip to content

Commit

Permalink
Merge pull request #125 from okotaku/feat/amused
Browse files Browse the repository at this point in the history
[Feature] Support aMUSEd
  • Loading branch information
okotaku committed Jan 12, 2024
2 parents c0610dd + 228772e commit e611420
Show file tree
Hide file tree
Showing 20 changed files with 935 additions and 11 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,22 @@ For detailed user guides and advanced guides, please refer to our [Documentation
</ul>
</td>
</tr>
</td>
</tr>
</tbody>
<tbody>
<tr align="center" valign="bottom">
<td>
<b>aMUSEd</b>
</td>
</tr>
<tr valign="top">
<td>
<ul>
<li><a href="diffengine/configs/amused/README.md">aMUSEd (2024)</a></li>
</ul>
</td>
</tr>
</td>
</tr>
</tbody>
Expand Down
47 changes: 47 additions & 0 deletions diffengine/configs/_base_/datasets/pokemon_blip_amused_512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torchvision
from mmengine.dataset import DefaultSampler

from diffengine.datasets import HFDataset
from diffengine.datasets.transforms import (
ComputeaMUSEdMicroConds,
PackInputs,
RandomCrop,
RandomHorizontalFlip,
RandomTextDrop,
SaveImageShape,
TorchVisonTransformWrapper,
)
from diffengine.engine.hooks import TransformerCheckpointHook, VisualizationHook

train_pipeline = [
dict(type=SaveImageShape),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.Resize,
size=512, interpolation="bilinear"),
dict(type=RandomCrop, size=512),
dict(type=RandomHorizontalFlip, p=0.5),
dict(type=ComputeaMUSEdMicroConds),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.ToTensor),
dict(type=RandomTextDrop),
dict(type=PackInputs, input_keys=["img", "text", "micro_conds"]),
]
train_dataloader = dict(
batch_size=8,
num_workers=4,
dataset=dict(
type=HFDataset,
dataset="lambdalabs/pokemon-blip-captions",
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
dict(type=VisualizationHook, prompt=["yoda pokemon"] * 4),
dict(type=TransformerCheckpointHook),
]
4 changes: 2 additions & 2 deletions diffengine/configs/_base_/datasets/pokemon_blip_pixart.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
T5TextPreprocess,
TorchVisonTransformWrapper,
)
from diffengine.engine.hooks import PixArtCheckpointHook, VisualizationHook
from diffengine.engine.hooks import TransformerCheckpointHook, VisualizationHook

train_pipeline = [
dict(type=SaveImageShape),
Expand Down Expand Up @@ -50,5 +50,5 @@
prompt=["yoda pokemon"] * 4,
height=1024,
width=1024),
dict(type=PixArtCheckpointHook),
dict(type=TransformerCheckpointHook),
]
20 changes: 20 additions & 0 deletions diffengine/configs/_base_/models/amused_512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from diffusers import UVit2DModel, VQModel
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
)

from diffengine.models.editors import AMUSEd

base_model = "amused/amused-512"
model = dict(type=AMUSEd,
model=base_model,
tokenizer=dict(type=CLIPTokenizer.from_pretrained,
subfolder="tokenizer"),
text_encoder=dict(type=CLIPTextModelWithProjection.from_pretrained,
subfolder="text_encoder"),
vae=dict(
type=VQModel.from_pretrained,
subfolder="vqvae"),
transformer=dict(type=UVit2DModel.from_pretrained,
subfolder="transformer"))
83 changes: 83 additions & 0 deletions diffengine/configs/amused/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# aMUSEd

[aMUSEd: An Open MUSE Reproduction](https://arxiv.org/abs/2401.01808)

## Abstract

We present aMUSEd, an open-source, lightweight masked image model (MIM) for text-to-image generation based on MUSE. With 10 percent of MUSE's parameters, aMUSEd is focused on fast image generation. We believe MIM is under-explored compared to latent diffusion, the prevailing approach for text-to-image generation. Compared to latent diffusion, MIM requires fewer inference steps and is more interpretable. Additionally, MIM can be fine-tuned to learn additional styles with only a single image. We hope to encourage further exploration of MIM by demonstrating its effectiveness on large-scale text-to-image generation and releasing reproducible training code. We also release checkpoints for two models which directly produce images at 256x256 and 512x512 resolutions.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/d62b8007-2064-47ff-97c2-7c2377be3411"/>
</div>

## Citation

```
@misc{patil2024amused,
title={aMUSEd: An Open MUSE Reproduction},
author={Suraj Patil and William Berman and Robin Rombach and Patrick von Platen},
year={2024},
eprint={2401.01808},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

## Run Training

Run Training

```
# single gpu
$ diffengine train ${CONFIG_FILE}
# multi gpus
$ NPROC_PER_NODE=${GPU_NUM} diffengine train ${CONFIG_FILE}
# Example.
$ diffengine train amused_512_pokemon_blip
```

## Inference with diffusers

Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffusers.pipeline` module.

Before inferencing, we should convert weights for diffusers format,

```bash
$ diffengine convert ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
# Example
$ diffengine convert amused_512_pokemon_blip work_dirs/amused_512_pokemon_blip/epoch_50.pth work_dirs/amused_512_pokemon_blip --save-keys transformer
```

Then we can run inference.

```py
from pathlib import Path

import torch
from diffusers import AmusedPipeline, UVit2DModel
from peft import PeftModel

checkpoint = Path('work_dirs/amused_512_pokemon_blip')
prompt = 'yoda pokemon'

transformer = UVit2DModel.from_pretrained(checkpoint, subfolder='transformer')
pipe = AmusedPipeline.from_pretrained(
"amused/amused-512",
transformer=transformer,
torch_dtype=torch.float32,
).to("cuda")

img = pipe(
prompt,
width=512,
height=512,
).images[0]
img.save("demo.png")
```

## Results Example

#### amused_512_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/a525dd2b-6663-42fb-8251-4d8767c19818)
20 changes: 20 additions & 0 deletions diffengine/configs/amused/amused_512_pokemon_blip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from mmengine.config import read_base

with read_base():
from .._base_.datasets.pokemon_blip_amused_512 import *
from .._base_.default_runtime import *
from .._base_.models.amused_512 import *
from .._base_.schedules.stable_diffusion_50e import *

optim_wrapper = dict(
type=AmpOptimWrapper,
dtype="float16",
optimizer=dict(type=AdamW, lr=1e-4, weight_decay=1e-2),
paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
flat_decay_mult=0.0,
custom_keys={
".mlm_ln.weight": dict(decay_mult=0.0),
".embeddings.weight": dict(decay_mult=0.0),
}))
2 changes: 2 additions & 0 deletions diffengine/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AddConstantCaption,
CenterCrop,
CLIPImageProcessor,
ComputeaMUSEdMicroConds,
ComputePixArtImgInfo,
ComputeTimeIds,
ConcatMultipleImgs,
Expand Down Expand Up @@ -45,4 +46,5 @@
"DumpMaskedImage",
"TorchVisonTransformWrapper",
"ConcatMultipleImgs",
"ComputeaMUSEdMicroConds",
]
40 changes: 40 additions & 0 deletions diffengine/datasets/transforms/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,3 +896,43 @@ def transform(self,
for k in self.keys:
results[k] = torch.cat(results[k], dim=0)
return results


@TRANSFORMS.register_module()
class ComputeaMUSEdMicroConds(BaseTransform):
"""Compute aMUSEd micro_conds as 'micro_conds' in results."""

def transform(self, results: dict) -> dict | tuple[list, list] | None:
"""Transform.
Args:
----
results (dict): The result dict.
Returns:
-------
dict: 'micro_conds' key is added as original image shape.
"""
assert "ori_img_shape" in results
assert "crop_top_left" in results

micro_conds = []
if not isinstance(results["img"], list):
img = [results["img"]]
ori_img_shape = [results["ori_img_shape"]]
crop_top_left = [results["crop_top_left"]]
else:
img = results["img"]
ori_img_shape = results["ori_img_shape"]
crop_top_left = results["crop_top_left"]

for i in range(len(img)):
# ori_img_shape [H, W] -> [W, H]
aesthetic_score = 6.0
micro_conds.append(
ori_img_shape[i][::-1] + crop_top_left[i] + [aesthetic_score])

if not isinstance(results["img"], list):
micro_conds = micro_conds[0]
results["micro_conds"] = micro_conds
return results
4 changes: 2 additions & 2 deletions diffengine/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from .ip_adapter_save_hook import IPAdapterSaveHook
from .lcm_ema_update_hook import LCMEMAUpdateHook
from .peft_save_hook import PeftSaveHook
from .pixart_checkpoint_hook import PixArtCheckpointHook
from .prior_save_hook import PriorSaveHook
from .sd_checkpoint_hook import SDCheckpointHook
from .t2i_adapter_save_hook import T2IAdapterSaveHook
from .transformer_checkpoint_hook import TransformerCheckpointHook
from .unet_ema_hook import UnetEMAHook
from .visualization_hook import VisualizationHook

Expand All @@ -23,5 +23,5 @@
"FastNormHook",
"PriorSaveHook",
"LCMEMAUpdateHook",
"PixArtCheckpointHook",
"TransformerCheckpointHook",
]
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@HOOKS.register_module()
class PixArtCheckpointHook(Hook):
class TransformerCheckpointHook(Hook):
"""Delete 'vae' from checkpoint for efficient save."""

priority = "VERY_LOW"
Expand Down
1 change: 1 addition & 0 deletions diffengine/models/editors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .amused import * # noqa: F403
from .deepfloyd_if import * # noqa: F403
from .distill_sd import * # noqa: F403
from .esd import * # noqa: F403
Expand Down
4 changes: 4 additions & 0 deletions diffengine/models/editors/amused/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .amused import AMUSEd
from .amused_data_preprocessor import AMUSEdPreprocessor

__all__ = ["AMUSEd", "AMUSEdPreprocessor"]
Loading

0 comments on commit e611420

Please sign in to comment.