generated from okotaku/template
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #125 from okotaku/feat/amused
[Feature] Support aMUSEd
- Loading branch information
Showing
20 changed files
with
935 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
47 changes: 47 additions & 0 deletions
47
diffengine/configs/_base_/datasets/pokemon_blip_amused_512.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torchvision | ||
from mmengine.dataset import DefaultSampler | ||
|
||
from diffengine.datasets import HFDataset | ||
from diffengine.datasets.transforms import ( | ||
ComputeaMUSEdMicroConds, | ||
PackInputs, | ||
RandomCrop, | ||
RandomHorizontalFlip, | ||
RandomTextDrop, | ||
SaveImageShape, | ||
TorchVisonTransformWrapper, | ||
) | ||
from diffengine.engine.hooks import TransformerCheckpointHook, VisualizationHook | ||
|
||
train_pipeline = [ | ||
dict(type=SaveImageShape), | ||
dict(type=TorchVisonTransformWrapper, | ||
transform=torchvision.transforms.Resize, | ||
size=512, interpolation="bilinear"), | ||
dict(type=RandomCrop, size=512), | ||
dict(type=RandomHorizontalFlip, p=0.5), | ||
dict(type=ComputeaMUSEdMicroConds), | ||
dict(type=TorchVisonTransformWrapper, | ||
transform=torchvision.transforms.ToTensor), | ||
dict(type=RandomTextDrop), | ||
dict(type=PackInputs, input_keys=["img", "text", "micro_conds"]), | ||
] | ||
train_dataloader = dict( | ||
batch_size=8, | ||
num_workers=4, | ||
dataset=dict( | ||
type=HFDataset, | ||
dataset="lambdalabs/pokemon-blip-captions", | ||
pipeline=train_pipeline), | ||
sampler=dict(type=DefaultSampler, shuffle=True), | ||
) | ||
|
||
val_dataloader = None | ||
val_evaluator = None | ||
test_dataloader = val_dataloader | ||
test_evaluator = val_evaluator | ||
|
||
custom_hooks = [ | ||
dict(type=VisualizationHook, prompt=["yoda pokemon"] * 4), | ||
dict(type=TransformerCheckpointHook), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from diffusers import UVit2DModel, VQModel | ||
from transformers import ( | ||
CLIPTextModelWithProjection, | ||
CLIPTokenizer, | ||
) | ||
|
||
from diffengine.models.editors import AMUSEd | ||
|
||
base_model = "amused/amused-512" | ||
model = dict(type=AMUSEd, | ||
model=base_model, | ||
tokenizer=dict(type=CLIPTokenizer.from_pretrained, | ||
subfolder="tokenizer"), | ||
text_encoder=dict(type=CLIPTextModelWithProjection.from_pretrained, | ||
subfolder="text_encoder"), | ||
vae=dict( | ||
type=VQModel.from_pretrained, | ||
subfolder="vqvae"), | ||
transformer=dict(type=UVit2DModel.from_pretrained, | ||
subfolder="transformer")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# aMUSEd | ||
|
||
[aMUSEd: An Open MUSE Reproduction](https://arxiv.org/abs/2401.01808) | ||
|
||
## Abstract | ||
|
||
We present aMUSEd, an open-source, lightweight masked image model (MIM) for text-to-image generation based on MUSE. With 10 percent of MUSE's parameters, aMUSEd is focused on fast image generation. We believe MIM is under-explored compared to latent diffusion, the prevailing approach for text-to-image generation. Compared to latent diffusion, MIM requires fewer inference steps and is more interpretable. Additionally, MIM can be fine-tuned to learn additional styles with only a single image. We hope to encourage further exploration of MIM by demonstrating its effectiveness on large-scale text-to-image generation and releasing reproducible training code. We also release checkpoints for two models which directly produce images at 256x256 and 512x512 resolutions. | ||
|
||
<div align=center> | ||
<img src="https://github.com/okotaku/diffengine/assets/24734142/d62b8007-2064-47ff-97c2-7c2377be3411"/> | ||
</div> | ||
|
||
## Citation | ||
|
||
``` | ||
@misc{patil2024amused, | ||
title={aMUSEd: An Open MUSE Reproduction}, | ||
author={Suraj Patil and William Berman and Robin Rombach and Patrick von Platen}, | ||
year={2024}, | ||
eprint={2401.01808}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CV} | ||
} | ||
``` | ||
|
||
## Run Training | ||
|
||
Run Training | ||
|
||
``` | ||
# single gpu | ||
$ diffengine train ${CONFIG_FILE} | ||
# multi gpus | ||
$ NPROC_PER_NODE=${GPU_NUM} diffengine train ${CONFIG_FILE} | ||
# Example. | ||
$ diffengine train amused_512_pokemon_blip | ||
``` | ||
|
||
## Inference with diffusers | ||
|
||
Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffusers.pipeline` module. | ||
|
||
Before inferencing, we should convert weights for diffusers format, | ||
|
||
```bash | ||
$ diffengine convert ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS} | ||
# Example | ||
$ diffengine convert amused_512_pokemon_blip work_dirs/amused_512_pokemon_blip/epoch_50.pth work_dirs/amused_512_pokemon_blip --save-keys transformer | ||
``` | ||
|
||
Then we can run inference. | ||
|
||
```py | ||
from pathlib import Path | ||
|
||
import torch | ||
from diffusers import AmusedPipeline, UVit2DModel | ||
from peft import PeftModel | ||
|
||
checkpoint = Path('work_dirs/amused_512_pokemon_blip') | ||
prompt = 'yoda pokemon' | ||
|
||
transformer = UVit2DModel.from_pretrained(checkpoint, subfolder='transformer') | ||
pipe = AmusedPipeline.from_pretrained( | ||
"amused/amused-512", | ||
transformer=transformer, | ||
torch_dtype=torch.float32, | ||
).to("cuda") | ||
|
||
img = pipe( | ||
prompt, | ||
width=512, | ||
height=512, | ||
).images[0] | ||
img.save("demo.png") | ||
``` | ||
|
||
## Results Example | ||
|
||
#### amused_512_pokemon_blip | ||
|
||
![example1](https://github.com/okotaku/diffengine/assets/24734142/a525dd2b-6663-42fb-8251-4d8767c19818) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .._base_.datasets.pokemon_blip_amused_512 import * | ||
from .._base_.default_runtime import * | ||
from .._base_.models.amused_512 import * | ||
from .._base_.schedules.stable_diffusion_50e import * | ||
|
||
optim_wrapper = dict( | ||
type=AmpOptimWrapper, | ||
dtype="float16", | ||
optimizer=dict(type=AdamW, lr=1e-4, weight_decay=1e-2), | ||
paramwise_cfg=dict( | ||
norm_decay_mult=0.0, | ||
bias_decay_mult=0.0, | ||
flat_decay_mult=0.0, | ||
custom_keys={ | ||
".mlm_ln.weight": dict(decay_mult=0.0), | ||
".embeddings.weight": dict(decay_mult=0.0), | ||
})) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .amused import AMUSEd | ||
from .amused_data_preprocessor import AMUSEdPreprocessor | ||
|
||
__all__ = ["AMUSEd", "AMUSEdPreprocessor"] |
Oops, something went wrong.