Commit

IP-Adapter DINOv2
okotaku committed Feb 1, 2024
1 parent 27ab2d0 commit a90303e
Showing 7 changed files with 125 additions and 5 deletions.
@@ -0,0 +1,60 @@
import torchvision
from mmengine.dataset import DefaultSampler

from diffengine.datasets import HFDataset
from diffengine.datasets.transforms import (
ComputeTimeIds,
PackInputs,
RandomCrop,
RandomHorizontalFlip,
RandomTextDrop,
SaveImageShape,
TorchVisonTransformWrapper,
TransformersImageProcessor,
)
from diffengine.engine.hooks import IPAdapterSaveHook, VisualizationHook

train_pipeline = [
dict(type=SaveImageShape),
dict(type=TransformersImageProcessor,
pretrained="facebook/dinov2-base"),
dict(type=RandomTextDrop),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.Resize,
size=1024, interpolation="bilinear"),
dict(type=RandomCrop, size=1024),
dict(type=RandomHorizontalFlip, p=0.5),
dict(type=ComputeTimeIds),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.ToTensor),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.Normalize, mean=[0.5], std=[0.5]),
dict(
type=PackInputs, input_keys=["img", "text", "time_ids", "clip_img"]),
]
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(
type=HFDataset,
dataset="lambdalabs/pokemon-blip-captions",
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
dict(
type=VisualizationHook,
prompt=["a drawing of a green pokemon with red eyes"] * 2 + [""] * 2,
example_image=[
'https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg' # noqa
] * 4,
height=1024,
width=1024),
dict(type=IPAdapterSaveHook),
]
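As a quick sanity check (not part of the commit), the pipeline above can be composed with mmengine and run on a single sample; the local image path is hypothetical, and the exact packed structure depends on diffengine's PackInputs:

from mmengine.dataset.base_dataset import Compose
from PIL import Image

pipeline = Compose(train_pipeline)
sample = dict(
    img=Image.open("pokemon.jpg").convert("RGB"),  # hypothetical local image
    text="a drawing of a green pokemon with red eyes")
packed = pipeline(sample)
# "img", "text", "time_ids" and the DINOv2-preprocessed "clip_img" are the keys
# PackInputs is configured to collect above.
print(packed)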
@@ -2,6 +2,7 @@
from diffusers.models.embeddings import ImageProjection
from transformers import (
AutoTokenizer,
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
@@ -34,4 +35,5 @@
subfolder="sdxl_models/image_encoder"),
image_projection=dict(type=ImageProjection,
num_image_text_embeds=4),
feature_extractor=dict(type=CLIPImageProcessor),
gradient_checkpointing=True)
@@ -2,6 +2,7 @@
from diffusers.models.embeddings import IPAdapterPlusImageProjection
from transformers import (
AutoTokenizer,
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
@@ -39,4 +40,5 @@
heads=20,
num_queries=16,
ffn_ratio=4),
feature_extractor=dict(type=CLIPImageProcessor),
gradient_checkpointing=True)
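For reference (not from the diff), the image_projection entry above roughly corresponds to the diffusers module below; heads, num_queries and ffn_ratio come from the config, while the embedding dims are assumptions for a ViT-H image encoder paired with SDXL rather than values taken from the commit:

from diffusers.models.embeddings import IPAdapterPlusImageProjection

image_projection = IPAdapterPlusImageProjection(
    embed_dims=1280,   # assumed image-encoder hidden size (ViT-H)
    output_dims=2048,  # assumed SDXL cross-attention dim
    heads=20,
    num_queries=16,
    ffn_ratio=4)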
@@ -0,0 +1,20 @@
from mmengine.config import read_base
from transformers import AutoImageProcessor, Dinov2Model

with read_base():
from .._base_.datasets.pokemon_blip_xl_ip_adapter_dinov2 import *
from .._base_.default_runtime import *
from .._base_.models.stable_diffusion_xl_ip_adapter_plus import *
from .._base_.schedules.stable_diffusion_xl_50e import *


model.image_encoder = dict(
type=Dinov2Model.from_pretrained,
pretrained_model_name_or_path="facebook/dinov2-base")
model.feature_extractor = dict(
type=AutoImageProcessor.from_pretrained,
pretrained_model_name_or_path="facebook/dinov2-base")

train_dataloader.update(batch_size=1)

optim_wrapper.update(accumulative_counts=4)  # update weights once every 4 iterations
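A minimal sketch of the mmengine build behaviour these overrides rely on (an assumption about the framework, not code from the commit): when `type` is a callable such as `Dinov2Model.from_pretrained`, the builder pops it and calls it with the remaining keyword arguments, so the two overrides above are roughly equivalent to:

from transformers import AutoImageProcessor, Dinov2Model

image_encoder = Dinov2Model.from_pretrained("facebook/dinov2-base")
feature_extractor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

With batch_size=1 and accumulative_counts=4, gradients are accumulated over four iterations, for an effective per-device batch size of 4.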
2 changes: 2 additions & 0 deletions diffengine/datasets/transforms/__init__.py
@@ -20,6 +20,7 @@
SaveImageShape,
T5TextPreprocess,
TorchVisonTransformWrapper,
TransformersImageProcessor,
)
from .wrappers import RandomChoice

@@ -47,4 +48,5 @@
"TorchVisonTransformWrapper",
"ConcatMultipleImgs",
"ComputeaMUSEdMicroConds",
"TransformersImageProcessor",
]
33 changes: 33 additions & 0 deletions diffengine/datasets/transforms/processing.py
@@ -14,6 +14,7 @@
from mmengine.dataset.base_dataset import Compose
from torchvision.transforms.functional import crop
from torchvision.transforms.transforms import InterpolationMode
from transformers import AutoImageProcessor
from transformers import CLIPImageProcessor as HFCLIPImageProcessor

from diffengine.datasets.transforms.base import BaseTransform
@@ -936,3 +937,35 @@ def transform(self, results: dict) -> dict | tuple[list, list] | None:
micro_conds = micro_conds[0]
results["micro_conds"] = micro_conds
return results


@TRANSFORMS.register_module()
class TransformersImageProcessor(BaseTransform):
"""TransformersImageProcessor.
Args:
----
key (str): `key` to apply augmentation from results. Defaults to 'img'.
output_key (str): `output_key` after applying augmentation from
results. Defaults to 'clip_img'.
pretrained (str, optional): Name or path passed to
`AutoImageProcessor.from_pretrained`. Defaults to None.
"""

def __init__(self, key: str = "img", output_key: str = "clip_img",
pretrained: str | None = None) -> None:
self.key = key
self.output_key = output_key
self.pipeline = AutoImageProcessor.from_pretrained(pretrained)

def transform(self, results: dict) -> dict | tuple[list, list] | None:
"""Transform.
Args:
----
results (dict): The result dict.
"""
assert not isinstance(results[self.key], list), (
"CLIPImageProcessor only support single image.")
# (1, 3, 224, 224) -> (3, 224, 224)
results[self.output_key] = self.pipeline(
images=results[self.key], return_tensors="pt").pixel_values[0]
return results
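A hedged usage sketch for the new transform (not part of the diff; the reference image path is hypothetical):

from PIL import Image

processor = TransformersImageProcessor(pretrained="facebook/dinov2-base")
results = processor.transform(
    {"img": Image.open("reference.jpg").convert("RGB")})
print(results["clip_img"].shape)  # expected torch.Size([3, 224, 224]) for dinov2-base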
11 changes: 6 additions & 5 deletions diffengine/models/editors/ip_adapter/ip_adapter_xl.py
@@ -6,15 +6,14 @@
from diffusers.utils import load_image
from PIL import Image
from torch import nn
from transformers import CLIPImageProcessor

from diffengine.models.archs import (
load_ip_adapter,
process_ip_adapter_state_dict,
set_unet_ip_adapter,
)
from diffengine.models.editors.stable_diffusion_xl import StableDiffusionXL
from diffengine.registry import MODELS
from diffengine.registry import MODELS, TRANSFORMS


@MODELS.register_module()
@@ -25,6 +24,7 @@ class IPAdapterXL(StableDiffusionXL):
----
image_encoder (dict): The image encoder config.
image_projection (dict): The image projection config.
feature_extractor (dict): The feature extractor config.
pretrained_adapter (str, optional): Path to pretrained IP-Adapter.
Defaults to None.
pretrained_adapter_subfolder (str, optional): Sub folder of pretrained
@@ -54,6 +54,7 @@ def __init__(self,
*args,
image_encoder: dict,
image_projection: dict,
feature_extractor: dict,
pretrained_adapter: str | None = None,
pretrained_adapter_subfolder: str = "",
pretrained_adapter_weights_name: str = "",
@@ -79,6 +80,8 @@ def __init__(self,
self.pretrained_adapter_weights_name = pretrained_adapter_weights_name
self.zeros_image_embeddings_prob = zeros_image_embeddings_prob

self.feature_extractor = TRANSFORMS.build(feature_extractor)

super().__init__(
*args,
unet_lora_config=unet_lora_config,
@@ -161,7 +164,7 @@ def infer(self,
tokenizer_2=self.tokenizer_two,
unet=self.unet,
image_encoder=self.image_encoder,
feature_extractor=CLIPImageProcessor(),
feature_extractor=self.feature_extractor,
torch_dtype=(torch.float16 if self.device != torch.device("cpu")
else torch.float32),
)
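A hedged sketch of exercising the updated infer() once the model is built (not in the diff; the keyword names are assumed from the VisualizationHook arguments above, and the reference image path is hypothetical):

images = model.infer(
    prompt=["a drawing of a green pokemon with red eyes"],
    example_image=["reference_pokemon.png"],
    height=1024,
    width=1024)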
@@ -274,7 +277,6 @@ def forward(
replacement=True).to(image_embeds)
image_embeds = image_embeds * mask.view(-1, 1, 1, 1)

# TODO(takuoko): drop image # noqa
ip_tokens = self.image_projection(image_embeds)
prompt_embeds = torch.cat([prompt_embeds, ip_tokens], dim=1)

@@ -377,7 +379,6 @@ def forward(
image_embeds = self.image_encoder(
clip_img, output_hidden_states=True).hidden_states[-2]

# TODO(takuoko): drop image # noqa
ip_tokens = self.image_projection(image_embeds)
prompt_embeds = torch.cat([prompt_embeds, ip_tokens], dim=1)
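For context (an assumption, not in the diff), the hidden_states[-2] features this forward pass feeds into the image projection have the following shape for DINOv2-base: a 224x224 crop with patch size 14 gives 16*16 patch tokens plus a CLS token:

import torch
from transformers import Dinov2Model

encoder = Dinov2Model.from_pretrained("facebook/dinov2-base")
clip_img = torch.randn(1, 3, 224, 224)
image_embeds = encoder(clip_img, output_hidden_states=True).hidden_states[-2]
print(image_embeds.shape)  # expected torch.Size([1, 257, 768])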

