Skip to content

Commit

Permalink
Support IP-Adapter DINO
Browse files Browse the repository at this point in the history
  • Loading branch information
okotaku committed Feb 10, 2024
1 parent 232ca95 commit a474b1a
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
type=VisualizationHook,
prompt=["a drawing of a green pokemon with red eyes"] * 2 + [""] * 2,
example_image=[
'https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg' # noqa
'https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true' # noqa
] * 4,
height=1024,
width=1024),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torchvision
from mmengine.dataset import DefaultSampler

from diffengine.datasets import HFDataset
from diffengine.datasets.transforms import (
ComputeTimeIds,
PackInputs,
RandomCrop,
RandomHorizontalFlip,
RandomTextDrop,
SaveImageShape,
TorchVisonTransformWrapper,
TransformersImageProcessor,
)
from diffengine.engine.hooks import IPAdapterSaveHook, VisualizationHook

train_pipeline = [
dict(type=SaveImageShape),
dict(type=TransformersImageProcessor,
pretrained="facebook/dinov2-giant"),
dict(type=RandomTextDrop),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.Resize,
size=1024, interpolation="bilinear"),
dict(type=RandomCrop, size=1024),
dict(type=RandomHorizontalFlip, p=0.5),
dict(type=ComputeTimeIds),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.ToTensor),
dict(type=TorchVisonTransformWrapper,
transform=torchvision.transforms.Normalize, mean=[0.5], std=[0.5]),
dict(
type=PackInputs, input_keys=["img", "text", "time_ids", "clip_img"]),
]
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(
type=HFDataset,
dataset="lambdalabs/pokemon-blip-captions",
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
dict(
type=VisualizationHook,
prompt=["a drawing of a green pokemon with red eyes"] * 2 + [""] * 2,
example_image=[
'https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true' # noqa
] * 4,
height=1024,
width=1024),
dict(type=IPAdapterSaveHook),
]
12 changes: 12 additions & 0 deletions diffengine/configs/ip_adapter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,15 @@ You can see more details on [`docs/source/run_guides/run_ip_adapter.md`](../../d
![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/ace81220-010b-44a5-aa8f-3acdf3f54433)

#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus_dinov2

![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/5e1e2088-d00b-4909-9c64-61a7b5ac6b44)

#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus_dinov2_giant

![input1](https://github.com/LambdaLabsML/examples/blob/main/stable-diffusion-finetuning/README_files/README_2_0.png?raw=true)

![example1](https://github.com/okotaku/diffengine/assets/24734142/f76c33ba-c1ac-4f6f-b256-d48de5e58bf8)
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@
train_dataloader.update(batch_size=1)

optim_wrapper.update(accumulative_counts=4) # update every four times

train_cfg.update(by_epoch=True, max_epochs=100)
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from mmengine.config import read_base
from transformers import AutoImageProcessor, Dinov2Model

with read_base():
from .._base_.datasets.pokemon_blip_xl_ip_adapter_dinov2_giant import *
from .._base_.default_runtime import *
from .._base_.models.stable_diffusion_xl_ip_adapter_plus import *
from .._base_.schedules.stable_diffusion_xl_50e import *


model.image_encoder = dict(
type=Dinov2Model.from_pretrained,
pretrained_model_name_or_path="facebook/dinov2-giant")
model.feature_extractor = dict(
type=AutoImageProcessor.from_pretrained,
pretrained_model_name_or_path="facebook/dinov2-giant")

train_dataloader.update(batch_size=1)

optim_wrapper.update(accumulative_counts=4) # update every four times

train_cfg.update(by_epoch=True, max_epochs=100)

0 comments on commit a474b1a

Please sign in to comment.