[Feat] Update libs #129

Merged
merged 3 commits on Feb 5, 2024
Changes from 1 commit
Update libs
okotaku committed Feb 5, 2024
commit f0dab7108c3c87b67ebb0fca438a437523892022
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
@@ -4,7 +4,7 @@
],
"service": "diffengine",
"workspaceFolder": "/workspace",
"postCreateCommand": "pre-commit install",
"postCreateCommand": "pre-commit install && gh auth login && gh extension install github/gh-copilot",
"customizations": {
"vscode": {
"extensions": [
16 changes: 9 additions & 7 deletions Dockerfile
100644 → 100755
@@ -1,7 +1,7 @@
FROM nvcr.io/nvidia/pytorch:23.12-py3
FROM nvcr.io/nvidia/pytorch:24.01-py3

RUN apt update -y && apt install -y \
git tmux
git tmux gh
RUN apt-get update && apt-get install -y \
vim \
libgl1-mesa-dev \
@@ -19,13 +19,15 @@ RUN pip install --upgrade pip

# Install xformers
RUN pip install ninja
RUN export TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6 9.0+PTX" MAX_JOBS=1 && \
pip install -v -U git+https://github.com/facebookresearch/xformers.git@v0.0.20#egg=xformers
RUN export TORCH_CUDA_ARCH_LIST="8.6 9.0+PTX" MAX_JOBS=8 && \
pip install -v -U git+https://github.com/facebookresearch/xformers.git@v0.0.24#egg=xformers
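Note: narrowing `TORCH_CUDA_ARCH_LIST` from the full `5.2 … 9.0+PTX` range to `8.6 9.0+PTX` builds xformers kernels only for Ampere/Hopper-class GPUs, which shortens the source build considerably, while `MAX_JOBS=8` parallelizes compilation. A quick way to confirm which entry your own GPU needs (a sketch, not part of the PR):

```python
import torch

# Print the local GPU's compute capability to pick a TORCH_CUDA_ARCH_LIST
# entry (e.g. 8.6 for an A6000, 9.0 for an H100). Requires a CUDA build.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"TORCH_CUDA_ARCH_LIST entry: {major}.{minor}")
```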

# Install DiffEngine
RUN pip install --no-cache-dir openmim==0.3.9 && \
pip install . && \
pip install pre-commit
RUN pip install . && \
pip install pre-commit && \
pip uninstall -y $(pip list --format=freeze | grep opencv) && \
rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \
pip install opencv-python-headless

# Language settings
ENV LANG C.UTF-8
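The uninstall-then-delete dance for OpenCV is there because the NVIDIA PyTorch image ships an OpenCV build whose `cv2/` directory is not fully owned by pip; the diff removes any pip-managed opencv packages, deletes the leftover `cv2/` directory by hand, and installs `opencv-python-headless`, presumably to avoid `libGL`-style import errors in headless containers. A minimal sanity check of the expected behavior (an assumption, not part of the PR):

```python
# After the swap, cv2 should import without GUI libraries present, and GUI
# entry points should fail cleanly, since headless wheels omit highgui
# window support. Sketch only; requires opencv-python-headless.
import cv2

print(cv2.__version__)
try:
    cv2.namedWindow("probe")
except cv2.error:
    print("headless build confirmed: no window backend compiled in")
```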
8 changes: 4 additions & 4 deletions diffengine/configs/ip_adapter/README.md
@@ -63,7 +63,7 @@ pipe = DiffusionPipeline.from_pretrained(
pipe.to('cuda')
pipe.load_ip_adapter("work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/step41650", subfolder="", weight_name="ip_adapter.bin")

image = load_image("https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg")
image = load_image("https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true")

image = pipe(
prompt,
@@ -80,18 +80,18 @@ You can see more details on [`docs/source/run_guides/run_ip_adapter.md`](../../d

#### stable_diffusion_xl_pokemon_blip_ip_adapter

![input1](https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg)
![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/6137ffb4-dff9-41de-aa6e-2910d95e6d21)

#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus

![input1](https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg)
![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/723ad39d-9e0f-441b-80f7-cf9bcfd12853)

#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus_pretrained

![input1](https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg)
![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/ace81220-010b-44a5-aa8f-3acdf3f54433)
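For reference, the README's inference snippet assembles as follows with the new image URL in place (a sketch: the SDXL base checkpoint name and the prompt are assumptions, not shown in the diff):

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed base model
    torch_dtype=torch.float16)
pipe.to("cuda")
pipe.load_ip_adapter(
    "work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/step41650",
    subfolder="", weight_name="ip_adapter.bin")

image = load_image(
    "https://github.com/LambdaLabsML/examples/blob/main/"
    "stable-diffusion-finetuning/README_files/README_2_0.png?raw=true")
out = pipe("yoda pokemon",  # assumed prompt
           ip_adapter_image=image).images[0]
out.save("demo.png")
```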
12 changes: 6 additions & 6 deletions diffengine/configs/stable_diffusion_xl/README.md
@@ -36,17 +36,17 @@ $ diffengine train stable_diffusion_xl_pokemon_blip
Environment:

- A6000 Single GPU
- nvcr.io/nvidia/pytorch:23.10-py3
- nvcr.io/nvidia/pytorch:24.01-py3

Settings:

- 1epoch training.

| Model | total time |
| :---------------------------------------: | :--------: |
| stable_diffusion_xl_pokemon_blip (fp16) | 12 m 37 s |
| stable_diffusion_xl_pokemon_blip_xformers | 10 m 6 s |
| stable_diffusion_xl_pokemon_blip_fast | 9 m 47 s |
| Model | total time |
| :------------------------------------------------------------------------: | :--------: |
| stable_diffusion_xl_pokemon_blip (fp16 / nvcr.io/nvidia/pytorch:23.10-py3) | 12 m 37 s |
| stable_diffusion_xl_pokemon_blip_xformers | 10 m 4 s |
| stable_diffusion_xl_pokemon_blip_fast | 9 m 36 s |

Note that `stable_diffusion_xl_pokemon_blip_fast` spends a few extra minutes on compilation; that compilation time is disregarded in the totals above.
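The speedups implied by the updated table, as plain arithmetic over the timings above:

```python
# Speedups over the fp16 baseline, computed from the table's timings.
baseline = 12 * 60 + 37   # 757 s (fp16, 23.10 image)
xformers = 10 * 60 + 4    # 604 s
fast = 9 * 60 + 36        # 576 s
print(f"xformers: {baseline / xformers:.2f}x faster")  # ~1.25x
print(f"fast:     {baseline / fast:.2f}x faster")      # ~1.31x
```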

14 changes: 14 additions & 0 deletions diffengine/engine/hooks/visualization_hook.py
@@ -43,6 +43,20 @@ def __init__(self,
self.height = height
self.width = width

def before_train(self, runner: Runner) -> None:
"""Before train hook."""
model = runner.model
if is_model_wrapper(model):
model = model.module
images = model.infer(
self.prompt,
height=self.height,
width=self.width,
**self.kwargs)
for i, image in enumerate(images):
runner.visualizer.add_image(
f"image{i}_step", image, step=runner.iter)

def after_train_iter(
self,
runner: Runner,
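The new `before_train` hook renders the configured prompts once before the first optimization step, so every run gets a step-0 baseline image to compare later visualizations against. A sketch of how the hook might be registered in an MMEngine-style config (hypothetical; the prompt and sizes are illustrative):

```python
# Hypothetical config snippet: register the hook so baseline images are
# logged at step 0, in addition to the existing per-iteration logging.
custom_hooks = [
    dict(
        type="VisualizationHook",
        prompt=["yoda pokemon"],  # illustrative prompt
        height=1024,
        width=1024),
]
```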
106 changes: 50 additions & 56 deletions diffengine/models/archs/ip_adapter.py
@@ -1,3 +1,4 @@
import copy
from collections import OrderedDict

import torch
@@ -55,7 +56,7 @@ def set_unet_ip_adapter(unet: nn.Module) -> None:
unet.set_attn_processor(attn_procs)


def load_ip_adapter( # noqa: PLR0915, C901, PLR0912
def load_ip_adapter( # noqa: C901, PLR0912
unet: nn.Module,
image_projection: nn.Module,
pretrained_adapter: str,
@@ -101,70 +102,59 @@ def load_ip_adapter( # noqa: PLR0915, C901, PLR0912
if cross_attention_dim is None or "motion_modules" in name:
continue
value_dict = {}
for k in attn_proc.state_dict():
value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
value_dict.update(
{"to_k_ip.0.weight":
state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update(
{"to_v_ip.0.weight":
state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})

attn_proc.load_state_dict(value_dict)
key_id += 2

image_proj_state_dict = {}
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
},
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
for key, value in state_dict["image_proj"].items():
diffusers_name = key.replace("proj", "image_embeds")
image_proj_state_dict[diffusers_name] = value
elif "proj.3.weight" in state_dict["image_proj"]:
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
"ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
"ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
"ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
"norm.weight": state_dict["image_proj"]["proj.3.weight"],
"norm.bias": state_dict["image_proj"]["proj.3.bias"],
},
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
# IP-Adapter Full
for key, value in state_dict["image_proj"].items():
diffusers_name = key.replace("proj.0", "ff.net.0.proj")
diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
diffusers_name = diffusers_name.replace("proj.3", "norm")
image_proj_state_dict[diffusers_name] = value
else:
# IP-Adapter Plus
new_sd = OrderedDict()
for k, v in state_dict["image_proj"].items():
if "0.to" in k:
new_k = k.replace("0.to", "2.to")
elif "1.0.weight" in k:
new_k = k.replace("1.0.weight", "3.0.weight")
elif "1.0.bias" in k:
new_k = k.replace("1.0.bias", "3.0.bias")
elif "1.1.weight" in k:
new_k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
elif "1.3.weight" in k:
new_k = k.replace("1.3.weight", "3.1.net.2.weight")
for key, value in state_dict["image_proj"].items():
diffusers_name = key.replace("0.to", "2.to")
diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
diffusers_name = diffusers_name.replace(
"1.1.weight", "3.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")

if "norm1" in diffusers_name:
image_proj_state_dict[
diffusers_name.replace("0.norm1", "0")] = value
elif "norm2" in diffusers_name:
image_proj_state_dict[
diffusers_name.replace("0.norm2", "1")] = value
elif "to_kv" in diffusers_name:
v_chunk = value.chunk(2, dim=0)
image_proj_state_dict[
diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
image_proj_state_dict[
diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in diffusers_name:
image_proj_state_dict[
diffusers_name.replace("to_out", "to_out.0")] = value
else:
new_k = k
image_proj_state_dict[diffusers_name] = value

if "norm1" in new_k:
new_sd[new_k.replace("0.norm1", "0")] = v
elif "norm2" in new_k:
new_sd[new_k.replace("0.norm2", "1")] = v
elif "to_kv" in new_k:
v_chunk = v.chunk(2, dim=0)
new_sd[new_k.replace("to_kv", "to_k")] = v_chunk[0]
new_sd[new_k.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in new_k:
new_sd[new_k.replace("to_out", "to_out.0")] = v
else:
new_sd[new_k] = v
image_projection.load_state_dict(new_sd)
del state_dict
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict, state_dict
torch.cuda.empty_cache()
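The rewrite above trades the if/elif renaming ladders for flat `str.replace` passes into a single `image_proj_state_dict`, loaded once at the end. The one genuinely structural conversion is the IP-Adapter Plus `to_kv` projection, which is fused in the original checkpoint but split into separate `to_k`/`to_v` weights in diffusers' layout. The split in isolation (a standalone sketch with made-up sizes, not code from the PR):

```python
import torch

# IP-Adapter Plus checkpoints fuse the K and V projections into one matrix;
# the loader halves it along dim 0, exactly what value.chunk(2, dim=0) does.
dim, kv_dim = 768, 1024                 # made-up sizes for illustration
to_kv = torch.randn(2 * kv_dim, dim)    # fused checkpoint weight
to_k, to_v = to_kv.chunk(2, dim=0)
assert to_k.shape == to_v.shape == (kv_dim, dim)
```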


@@ -173,7 +163,11 @@ def process_ip_adapter_state_dict( # noqa: PLR0915, C901, PLR0912
"""Process IP-Adapter state dict."""
adapter_modules = torch.nn.ModuleList([
v if isinstance(v, nn.Module) else nn.Identity(
) for v in unet.attn_processors.values()])
) for v in copy.deepcopy(unet.attn_processors).values()])
adapter_state_dict = OrderedDict()
for k, v in adapter_modules.state_dict().items():
new_k = k.replace(".0.weight", ".weight")
adapter_state_dict[new_k] = v

# not save no grad key
ip_image_projection_state_dict = OrderedDict()
@@ -234,4 +228,4 @@ def process_ip_adapter_state_dict( # noqa: PLR0915, C901, PLR0912
ip_image_projection_state_dict[new_k] = v

return {"image_proj": ip_image_projection_state_dict,
"ip_adapter": adapter_modules.state_dict()}
"ip_adapter": adapter_state_dict}
14 changes: 7 additions & 7 deletions pyproject.toml
@@ -6,16 +6,16 @@ version = "1.0.0"
requires-python = ">= 3.10"
readme = "README.md"
dependencies = [
"torch>=2.0.1",
"torchvision>=0.15.2",
"datasets>=2.14.6",
"diffusers>=0.25.0",
"mmengine>=0.10.1",
"torch>=2.1.2",
"torchvision>=0.16.2",
"datasets>=2.16.1",
"diffusers@git+https://github.com/huggingface/diffusers@ec9840a#egg=diffusers",
"mmengine>=0.10.3",
"sentencepiece>=0.1.99",
"tqdm",
"transformers>=4.35.2",
"transformers>=4.37.2",
"ujson",
"peft>=0.7.0",
"peft>=0.8.1",
"joblib",
]
license = { file = "LICENSE" }
19 changes: 19 additions & 0 deletions tests/test_engine/test_hooks/test_visualization_hook.py
@@ -22,6 +22,25 @@ def tearDown(self):
MODELS.module_dict.pop("TimeSteps")
return super().tearDown()

def test_before_train(self):
runner = MagicMock()

# test epoch-based
runner.train_loop = MagicMock(spec=EpochBasedTrainLoop)
runner.epoch = 5
hook = VisualizationHook(prompt=["a dog"])
hook.before_train(runner)

def test_before_train_with_condition(self):
runner = MagicMock()

# test epoch-based
runner.train_loop = MagicMock(spec=EpochBasedTrainLoop)
runner.epoch = 5
hook = VisualizationHook(
prompt=["a dog"], condition_image=["testdata/color.jpg"])
hook.before_train(runner)

def test_after_train_epoch(self):
runner = MagicMock()
