Skip to content

Commit

Permalink
Fix IP-Adapter
Browse files: browse the repository at this point in the history
  • Loading branch information
okotaku committed Feb 7, 2024
1 parent 71631e9 commit 9bf56e9
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions diffengine/models/editors/ip_adapter/ip_adapter_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers.models.embeddings import MultiIPAdapterImageProjection
from diffusers.utils import load_image
from PIL import Image
from torch import nn
Expand Down Expand Up @@ -167,9 +168,21 @@ def infer(self,
)
adapter_state_dict = process_ip_adapter_state_dict(
self.unet, self.image_projection)
pipeline.load_ip_adapter(
pretrained_model_name_or_path_or_dict=adapter_state_dict,
subfolder="", weight_name="")

# Convert the IP-Adapter image projection weights into diffusers layers.
image_projection_layers = []
for state_dict in [adapter_state_dict]:
image_projection_layer = (
pipeline.unet._convert_ip_adapter_image_proj_to_diffusers( # noqa
state_dict["image_proj"]))
image_projection_layer.to(
device=pipeline.unet.device, dtype=pipeline.unet.dtype)
image_projection_layers.append(image_projection_layer)

pipeline.unet.encoder_hid_proj = MultiIPAdapterImageProjection(
image_projection_layers)
pipeline.unet.config.encoder_hid_dim_type = "ip_image_proj"

if self.prediction_type is not None:
# set prediction_type of scheduler if defined
scheduler_args = {"prediction_type": self.prediction_type}
Expand Down Expand Up @@ -276,12 +289,11 @@ def forward(

# TODO(takuoko): drop image # noqa
ip_tokens = self.image_projection(image_embeds)
prompt_embeds = torch.cat([prompt_embeds, ip_tokens], dim=1)

model_pred = self.unet(
noisy_latents,
timesteps,
prompt_embeds,
(prompt_embeds, ip_tokens),
added_cond_kwargs=unet_added_conditions).sample

return self.loss(model_pred, noise, latents, timesteps, weight)
Expand Down Expand Up @@ -379,12 +391,11 @@ def forward(

# TODO(takuoko): drop image # noqa
ip_tokens = self.image_projection(image_embeds)
prompt_embeds = torch.cat([prompt_embeds, ip_tokens], dim=1)

model_pred = self.unet(
noisy_latents,
timesteps,
prompt_embeds,
(prompt_embeds, ip_tokens),
added_cond_kwargs=unet_added_conditions).sample

return self.loss(model_pred, noise, latents, timesteps, weight)

0 comments on commit 9bf56e9

Please sign in to comment.