Merge pull request #134 from okotaku/feat/ip_adapter_last_hidden
[Feature] Support IP-Adapter Last hidden
okotaku authored Feb 15, 2024
2 parents 665b109 + 9a841c8 commit 4234611
Showing 5 changed files with 167 additions and 4 deletions.
6 changes: 6 additions & 0 deletions diffengine/configs/ip_adapter/README.md
@@ -107,3 +107,9 @@ You can see more details on [`docs/source/run_guides/run_ip_adapter.md`](../../docs/source/run_guides/run_ip_adapter.md)
![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/f76c33ba-c1ac-4f6f-b256-d48de5e58bf8)
+
+#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus_dinov2_giant_lasthidden
+
+![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)
+
+![example1](https://github.com/okotaku/diffengine/assets/24734142/4b37ce6c-60fd-4456-a542-74163927ee01)
23 changes: 23 additions & 0 deletions diffengine/configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter_plus_dinov2_giant_lasthidden.py
@@ -0,0 +1,23 @@
from mmengine.config import read_base
from transformers import AutoImageProcessor, Dinov2Model

with read_base():
    from .._base_.datasets.pokemon_blip_xl_ip_adapter_dinov2_giant import *
    from .._base_.default_runtime import *
    from .._base_.models.stable_diffusion_xl_ip_adapter_plus import *
    from .._base_.schedules.stable_diffusion_xl_50e import *


model.image_encoder = dict(
    type=Dinov2Model.from_pretrained,
    pretrained_model_name_or_path="facebook/dinov2-giant")
model.feature_extractor = dict(
    type=AutoImageProcessor.from_pretrained,
    pretrained_model_name_or_path="facebook/dinov2-giant")
model.update(dict(hidden_states_idx=-1))

train_dataloader.update(batch_size=1)

optim_wrapper.update(accumulative_counts=4)  # update every four iterations

train_cfg.update(by_epoch=True, max_epochs=100)
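For context (not part of this commit), here is a minimal sketch of what the `hidden_states_idx` override selects from a DINOv2 encoder; `facebook/dinov2-small` stands in for the giant checkpoint used above to keep the example light.

```python
# Minimal sketch (not from this commit): what hidden_states_idx selects.
# Assumes transformers is installed; dinov2-small stands in for dinov2-giant.
import torch
from transformers import Dinov2Model

encoder = Dinov2Model.from_pretrained("facebook/dinov2-small")
pixels = torch.zeros(1, 3, 224, 224)  # dummy image batch

with torch.no_grad():
    out = encoder(pixels, output_hidden_states=True)

# hidden_states_idx=-2 (the default) takes the penultimate layer.
penultimate = out.hidden_states[-2]
# hidden_states_idx=-1 maps to last_hidden_state, which (unlike
# hidden_states[-1]) has the model's final layernorm applied.
last = out.last_hidden_state
print(penultimate.shape, last.shape)  # both (1, 257, 384) for dinov2-small
```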
21 changes: 17 additions & 4 deletions diffengine/models/editors/ip_adapter/ip_adapter_xl.py
@@ -2,7 +2,6 @@

import numpy as np
import torch
-from diffusers import DiffusionPipeline
from diffusers.models.embeddings import MultiIPAdapterImageProjection
from diffusers.utils import load_image
from PIL import Image
@@ -13,6 +12,9 @@
    process_ip_adapter_state_dict,
    set_unet_ip_adapter,
)
+from diffengine.models.editors.ip_adapter.pipeline import (
+    StableDiffusionXLPipelineCustomIPAdapter,
+)
from diffengine.models.editors.stable_diffusion_xl import StableDiffusionXL
from diffengine.registry import MODELS, TRANSFORMS

@@ -49,6 +51,8 @@ class IPAdapterXL(StableDiffusionXL):
            generate zeros image embeddings. Defaults to 0.1.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`SDControlNetDataPreprocessor`.
+        hidden_states_idx (int): Index of the hidden states to be used.
+            Defaults to -2.
    """

    def __init__(self,
@@ -64,6 +68,7 @@ def __init__(self,
                 finetune_text_encoder: bool = False,
                 zeros_image_embeddings_prob: float = 0.1,
                 data_preprocessor: dict | nn.Module | None = None,
+                 hidden_states_idx: int = -2,
                 **kwargs) -> None:
        if data_preprocessor is None:
            data_preprocessor = {"type": "IPAdapterXLDataPreprocessor"}
@@ -80,6 +85,7 @@
        self.pretrained_adapter_subfolder = pretrained_adapter_subfolder
        self.pretrained_adapter_weights_name = pretrained_adapter_weights_name
        self.zeros_image_embeddings_prob = zeros_image_embeddings_prob
+        self.hidden_states_idx = hidden_states_idx

        self.feature_extractor = TRANSFORMS.build(feature_extractor)

@@ -156,7 +162,7 @@ def infer(self,
        orig_encoder_hid_proj = self.unet.encoder_hid_proj
        orig_encoder_hid_dim_type = self.unet.config.encoder_hid_dim_type

-        pipeline = DiffusionPipeline.from_pretrained(
+        pipeline = StableDiffusionXLPipelineCustomIPAdapter.from_pretrained(
            self.model,
            vae=self.vae,
            text_encoder=self.text_encoder_one,
@@ -168,6 +174,7 @@
            feature_extractor=self.feature_extractor,
            torch_dtype=(torch.float16 if self.device != torch.device("cpu")
                         else torch.float32),
+            hidden_states_idx=self.hidden_states_idx,
        )
        adapter_state_dict = process_ip_adapter_state_dict(
            self.unet, self.image_projection)
@@ -386,8 +393,14 @@ def forward(
            replacement=True).to(clip_img)
        clip_img = clip_img * mask.view(-1, 1, 1, 1)
        # encode image
-        image_embeds = self.image_encoder(
-            clip_img, output_hidden_states=True).hidden_states[-2]
+        if self.hidden_states_idx == -1:
+            image_embeds = self.image_encoder(
+                clip_img, output_hidden_states=True,
+            ).last_hidden_state
+        else:
+            image_embeds = self.image_encoder(
+                clip_img, output_hidden_states=True,
+            ).hidden_states[self.hidden_states_idx]

        ip_tokens = self.image_projection(image_embeds)

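The same selection branch appears twice in this commit: in the training forward pass above and in the custom pipeline's `encode_image` below. As a reading aid (not code from the repository), the rule distills to:

```python
# Distillation of the hidden-state selection rule added in this commit
# (reading aid only; the repository inlines this branch at both call sites).
def select_image_features(encoder_output, hidden_states_idx: int = -2):
    """Pick the encoder feature map fed to the IP-Adapter image projection.

    -1 is special-cased to `last_hidden_state` because, for DINOv2-style
    encoders, the model's final layernorm is applied to `last_hidden_state`
    but not to `hidden_states[-1]`.
    """
    if hidden_states_idx == -1:
        return encoder_output.last_hidden_state
    return encoder_output.hidden_states[hidden_states_idx]
```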
89 changes: 89 additions & 0 deletions diffengine/models/editors/ip_adapter/pipeline.py
@@ -0,0 +1,89 @@
# flake8: noqa
import torch
from diffusers import StableDiffusionXLPipeline


class StableDiffusionXLPipelineCustomIPAdapter(StableDiffusionXLPipeline):
    """Custom IP-Adapter variant of StableDiffusionXLPipeline.

    The difference from the original StableDiffusionXLPipeline is that this
    class uses the hidden states from the `hidden_states_idx` layer of the
    image encoder to encode the image.

    Args:
        *args: Variable length argument list.
        hidden_states_idx (int): Index of the hidden states to be used.
            Defaults to -2.
        **kwargs: Arbitrary keyword arguments.
    """

    def __init__(self,
                 vae,
                 text_encoder,
                 text_encoder_2,
                 tokenizer,
                 tokenizer_2,
                 unet,
                 scheduler,
                 image_encoder=None,
                 feature_extractor=None,
                 force_zeros_for_empty_prompt=True,
                 add_watermarker=None,
                 hidden_states_idx: int = -2):
        super().__init__(vae=vae,
                         text_encoder=text_encoder,
                         text_encoder_2=text_encoder_2,
                         tokenizer=tokenizer,
                         tokenizer_2=tokenizer_2,
                         unet=unet,
                         scheduler=scheduler,
                         image_encoder=image_encoder,
                         feature_extractor=feature_extractor,
                         force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
                         add_watermarker=add_watermarker)
        self.hidden_states_idx = hidden_states_idx

    def encode_image(self, image, device, num_images_per_prompt,
                     output_hidden_states=None):
        """Encode the image.

        Args:
            image: The input image to be encoded.
            device: The device to be used for encoding.
            num_images_per_prompt: The number of images per prompt.
            output_hidden_states: Whether to output hidden states.
                Defaults to None.

        Returns:
            image_enc_hidden_states: Encoded hidden states of the image.
            uncond_image_enc_hidden_states: Encoded hidden states of the
                unconditional image.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            if self.hidden_states_idx == -1:
                image_enc_hidden_states = self.image_encoder(
                    image, output_hidden_states=True).last_hidden_state
            else:
                image_enc_hidden_states = self.image_encoder(
                    image, output_hidden_states=True,
                ).hidden_states[self.hidden_states_idx]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0)
            if self.hidden_states_idx == -1:
                uncond_image_enc_hidden_states = self.image_encoder(
                    torch.zeros_like(image), output_hidden_states=True,
                ).last_hidden_state
            else:
                uncond_image_enc_hidden_states = self.image_encoder(
                    torch.zeros_like(image), output_hidden_states=True,
                ).hidden_states[self.hidden_states_idx]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0,
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(
                num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds
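For completeness, a hypothetical usage sketch of the custom pipeline (the base-model id is a placeholder, not from this diff). `from_pretrained` forwards `hidden_states_idx` to `__init__` because it appears in the signature, which is how the patched `infer()` above passes it:

```python
# Hypothetical usage sketch (placeholder model id, not from this diff).
import torch

from diffengine.models.editors.ip_adapter.pipeline import (
    StableDiffusionXLPipelineCustomIPAdapter,
)

pipeline = StableDiffusionXLPipelineCustomIPAdapter.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder base model
    torch_dtype=torch.float16,
    hidden_states_idx=-1,  # routed to __init__ by from_pretrained
)
# As in infer() above, an image_encoder and feature_extractor still need to
# be supplied before IP-Adapter image conditioning can be used.
```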
32 changes: 32 additions & 0 deletions
@@ -126,6 +126,19 @@ def test_infer(self):
        assert type(result[0]) == torch.Tensor
        assert result[0].shape == (4, 32, 32)

+        # test infer with hidden_states_idx=-1
+        cfg = self._get_config()
+        cfg.update(hidden_states_idx=-1)
+        StableDiffuser = MODELS.build(cfg)
+        result = StableDiffuser.infer(
+            ["an insect robot preparing a delicious meal"],
+            ["tests/testdata/color.jpg"],
+            negative_prompt="noise",
+            height=64,
+            width=64)
+        assert len(result) == 1
+        assert result[0].shape == (64, 64, 3)
+
    def test_train_step(self):
        # test load with loss module
        cfg = self._get_config()
@@ -163,6 +176,25 @@ def test_train_step_with_gradient_checkpointing(self):
        assert log_vars
        assert isinstance(log_vars["loss"], torch.Tensor)

+    def test_train_step_last_hidden(self):
+        # test load with loss module
+        cfg = self._get_config()
+        cfg.update(hidden_states_idx=-1)
+        StableDiffuser = MODELS.build(cfg)
+
+        # test train step
+        data = dict(
+            inputs=dict(
+                img=[torch.zeros((3, 64, 64))],
+                text=["a dog"],
+                clip_img=[torch.zeros((3, 32, 32))],
+                time_ids=[torch.zeros((1, 6))]))
+        optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+        optim_wrapper = OptimWrapper(optimizer)
+        log_vars = StableDiffuser.train_step(data, optim_wrapper)
+        assert log_vars
+        assert isinstance(log_vars["loss"], torch.Tensor)
+
    def test_val_and_test_step(self):
        cfg = self._get_config()
        StableDiffuser = MODELS.build(cfg)