Merge pull request #132 from okotaku/feat/support_vae_batch

[Feature] Support VAE batch

okotaku committed Feb 8, 2024
2 parents f16b0cc + d4cfb35 commit 1f9006f
Showing 17 changed files with 123 additions and 42 deletions.
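The change repeated across these files is a small `_forward_vae` helper: instead of pushing the whole training batch through the VAE encoder at once, images are encoded in chunks of `vae_batch_size` (default 8) and the chunk latents are concatenated, bounding peak encoder memory. A minimal standalone sketch of the pattern, with a stub convolution in place of a real diffusers autoencoder so it runs without weights:

```python
# Minimal sketch of the chunked VAE encode added in this PR. `StubVAE` is a
# stand-in for a real autoencoder; only the chunk-and-concat logic mirrors
# the diffs below.
import torch
from torch import nn


class StubVAE(nn.Module):
    """Maps (N, 3, H, W) images to (N, 4, H/8, W/8) "latents"."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Conv2d(3, 4, kernel_size=8, stride=8)

    def encode(self, img: torch.Tensor) -> torch.Tensor:
        return self.proj(img)


def forward_vae(vae: nn.Module, img: torch.Tensor,
                vae_batch_size: int = 8) -> torch.Tensor:
    """Encode `img` in chunks of `vae_batch_size` to bound peak memory."""
    latents = [
        vae.encode(img[i : i + vae_batch_size])
        for i in range(0, img.shape[0], vae_batch_size)
    ]
    return torch.cat(latents, dim=0)


vae = StubVAE().eval()
img = torch.randn(20, 3, 64, 64)  # 20 images: chunks of 8, 8, 4
with torch.no_grad():
    chunked = forward_vae(vae, img)
    full = vae.encode(img)
assert torch.allclose(chunked, full, atol=1e-6)  # same output, lower peak memory
```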
18 changes: 15 additions & 3 deletions diffengine/models/editors/amused/amused.py
@@ -44,6 +44,7 @@ class AMUSEd(BaseModel):
             It works when training dreambooth with class images.
         data_preprocessor (dict, optional): The pre-process config of
             :class:`SDDataPreprocessor`.
+        vae_batch_size (int): The batch size of the VAE. Defaults to 8.
         finetune_text_encoder (bool, optional): Whether to fine-tune text
             encoder. Defaults to False.
         gradient_checkpointing (bool): Whether or not to use gradient
@@ -65,6 +66,7 @@ def __init__(
         text_encoder_lora_config: dict | None = None,
         prior_loss_weight: float = 1.,
         data_preprocessor: dict | nn.Module | None = None,
+        vae_batch_size: int = 8,
         *,
         finetune_text_encoder: bool = False,
         gradient_checkpointing: bool = False,
@@ -102,6 +104,7 @@ def __init__(
         self.prior_loss_weight = prior_loss_weight
         self.gradient_checkpointing = gradient_checkpointing
         self.enable_xformers = enable_xformers
+        self.vae_batch_size = vae_batch_size
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(
@@ -261,6 +264,17 @@ def test_step(
         msg = "test_step is not implemented now, please use infer."
         raise NotImplementedError(msg)
 
+    def _forward_vae(self, img: torch.Tensor, num_batches: int,
+                     ) -> torch.Tensor:
+        """Forward vae."""
+        latents = []
+        for i in range(0, num_batches, self.vae_batch_size):
+            latents_ = self.vae.encode(img[i : i + self.vae_batch_size]).latents
+            latents_ = self.vae.quantize(latents_)[2][2].reshape(
+                latents_.shape[0], -1)
+            latents.append(latents_)
+        return torch.cat(latents, dim=0)
+
     def forward(
         self,
         inputs: dict,
@@ -296,9 +310,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latents
-        latents = self.vae.quantize(latents)[2][2].reshape(
-            num_batches, -1)
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         timesteps = torch.rand(num_batches, device=self.device)
 
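AMUSEd's helper above is the odd one out: its VAE is a VQ model, so each chunk of latents is quantized to discrete token indices and reshaped to one row of tokens per image before concatenation. A toy nearest-neighbour quantizer illustrating the shape flow; `quantize_indices` and the pooling encoder are illustrative stand-ins, not diffengine or diffusers APIs:

```python
# Toy version of the VQ path: encode a chunk, map each spatial position to
# its nearest codebook index, and emit one row of token ids per image.
import torch
import torch.nn.functional as F


def quantize_indices(latents: torch.Tensor,
                     codebook: torch.Tensor) -> torch.Tensor:
    """Index of the nearest codebook entry for every spatial position."""
    flat = latents.permute(0, 2, 3, 1).reshape(-1, latents.shape[1])
    return torch.cdist(flat, codebook).argmin(dim=1)  # (b*h*w,)


def forward_vae_vq(encode, codebook: torch.Tensor, img: torch.Tensor,
                   vae_batch_size: int = 8) -> torch.Tensor:
    tokens = []
    for i in range(0, img.shape[0], vae_batch_size):
        lat = encode(img[i : i + vae_batch_size])     # (b, c, h, w)
        idx = quantize_indices(lat, codebook)
        tokens.append(idx.reshape(lat.shape[0], -1))  # one row per image
    return torch.cat(tokens, dim=0)


def encode(x: torch.Tensor) -> torch.Tensor:
    return F.avg_pool2d(x, 8)  # toy encoder: (N, 3, 64, 64) -> (N, 3, 8, 8)


codebook = torch.randn(16, 3)  # 16 codes matching the 3 toy channels
tokens = forward_vae_vq(encode, codebook, torch.randn(10, 3, 64, 64),
                        vae_batch_size=4)
print(tokens.shape)  # torch.Size([10, 64])
```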
3 changes: 1 addition & 2 deletions diffengine/models/editors/distill_sd/distill_sd_xl.py
@@ -188,8 +188,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
@@ -196,8 +196,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
@@ -230,7 +229,7 @@ def forward(
         }
 
         # condition
-        cond_latents = self.vae.encode(inputs["condition_img"]).latent_dist.sample()
+        cond_latents = self._forward_vae(inputs["condition_img"], num_batches)
         # random zeros cond latents
         mask = torch.multinomial(
             torch.Tensor([
6 changes: 2 additions & 4 deletions diffengine/models/editors/ip_adapter/ip_adapter_xl.py
@@ -258,8 +258,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
@@ -358,8 +357,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
16 changes: 15 additions & 1 deletion diffengine/models/editors/kandinsky/kandinskyv22_decoder.py
@@ -50,6 +50,7 @@ class KandinskyV22Decoder(BaseModel):
         input_perturbation_gamma (float): The gamma of input perturbation.
             The recommended value is 0.1 for Input Perturbation.
             Defaults to 0.0.
+        vae_batch_size (int): The batch size of the VAE. Defaults to 8.
         gradient_checkpointing (bool): Whether or not to use gradient
             checkpointing to save memory at the expense of slower backward
             pass. Defaults to False.
@@ -73,6 +74,7 @@ def __init__(
         noise_generator: dict | None = None,
         timesteps_generator: dict | None = None,
         input_perturbation_gamma: float = 0.0,
+        vae_batch_size: int = 8,
         *,
         gradient_checkpointing: bool = False,
         enable_xformers: bool = False,
@@ -93,6 +95,7 @@ def __init__(
         self.gradient_checkpointing = gradient_checkpointing
         self.input_perturbation_gamma = input_perturbation_gamma
         self.enable_xformers = enable_xformers
+        self.vae_batch_size = vae_batch_size
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(
@@ -303,6 +306,17 @@ def _preprocess_model_input(self,
             input_noise = noise
         return self.scheduler.add_noise(latents, input_noise, timesteps)
 
+    def _forward_vae(self, img: torch.Tensor, num_batches: int,
+                     ) -> torch.Tensor:
+        """Forward vae."""
+        latents = [
+            self.vae.encode(
+                img[i : i + self.vae_batch_size],
+            ).latents for i in range(
+                0, num_batches, self.vae_batch_size)
+        ]
+        return torch.cat(latents, dim=0)
+
     def forward(
         self,
         inputs: dict,
@@ -332,7 +346,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latents
+        latents = self._forward_vae(inputs["img"], num_batches)
         image_embeds = self.image_encoder(inputs["clip_img"]).image_embeds
 
         noise = self.noise_generator(latents)
16 changes: 15 additions & 1 deletion diffengine/models/editors/kandinsky/kandinskyv3.py
@@ -51,6 +51,7 @@ class KandinskyV3(BaseModel):
         input_perturbation_gamma (float): The gamma of input perturbation.
             The recommended value is 0.1 for Input Perturbation.
             Defaults to 0.0.
+        vae_batch_size (int): The batch size of the VAE. Defaults to 8.
         gradient_checkpointing (bool): Whether or not to use gradient
             checkpointing to save memory at the expense of slower backward
             pass. Defaults to False.
@@ -75,6 +76,7 @@ def __init__(
         noise_generator: dict | None = None,
         timesteps_generator: dict | None = None,
         input_perturbation_gamma: float = 0.0,
+        vae_batch_size: int = 8,
         *,
         gradient_checkpointing: bool = False,
         enable_xformers: bool = False,
@@ -98,6 +100,7 @@ def __init__(
         self.tokenizer_max_length = tokenizer_max_length
         self.input_perturbation_gamma = input_perturbation_gamma
         self.enable_xformers = enable_xformers
+        self.vae_batch_size = vae_batch_size
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(
@@ -312,6 +315,17 @@ def _preprocess_model_input(self,
             input_noise = noise
         return self.scheduler.add_noise(latents, input_noise, timesteps)
 
+    def _forward_vae(self, img: torch.Tensor, num_batches: int,
+                     ) -> torch.Tensor:
+        """Forward vae."""
+        latents = [
+            self.vae.encode(
+                img[i : i + self.vae_batch_size],
+            ).latents.contiguous() for i in range(
+                0, num_batches, self.vae_batch_size)
+        ]
+        return torch.cat(latents, dim=0)
+
     def forward(
         self,
         inputs: dict,
@@ -349,7 +363,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latents.contiguous()
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
3 changes: 1 addition & 2 deletions diffengine/models/editors/lcm/lcm_xl.py
@@ -253,8 +253,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
18 changes: 16 additions & 2 deletions diffengine/models/editors/pixart_alpha/pixart_alpha.py
@@ -56,6 +56,7 @@ class PixArtAlpha(BaseModel):
         input_perturbation_gamma (float): The gamma of input perturbation.
             The recommended value is 0.1 for Input Perturbation.
             Defaults to 0.0.
+        vae_batch_size (int): The batch size of the VAE. Defaults to 8.
         finetune_text_encoder (bool, optional): Whether to fine-tune text
             encoder. Defaults to False.
         gradient_checkpointing (bool): Whether or not to use gradient
@@ -83,6 +84,7 @@ def __init__(
         noise_generator: dict | None = None,
         timesteps_generator: dict | None = None,
         input_perturbation_gamma: float = 0.0,
+        vae_batch_size: int = 8,
         *,
         finetune_text_encoder: bool = False,
         gradient_checkpointing: bool = False,
@@ -126,6 +128,7 @@ def __init__(
         self.gradient_checkpointing = gradient_checkpointing
         self.input_perturbation_gamma = input_perturbation_gamma
         self.enable_xformers = enable_xformers
+        self.vae_batch_size = vae_batch_size
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(
@@ -353,6 +356,18 @@ def _preprocess_model_input(self,
             input_noise = noise
         return self.scheduler.add_noise(latents, input_noise, timesteps)
 
+    def _forward_vae(self, img: torch.Tensor, num_batches: int,
+                     ) -> torch.Tensor:
+        """Forward vae."""
+        latents = [
+            self.vae.encode(
+                img[i : i + self.vae_batch_size],
+            ).latent_dist.sample() for i in range(
+                0, num_batches, self.vae_batch_size)
+        ]
+        latents = torch.cat(latents, dim=0)
+        return latents * self.vae.config.scaling_factor
+
     def forward(
         self,
         inputs: dict,
@@ -390,8 +405,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
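PixArtAlpha, and SSD1B and StableDiffusion below, use the KL-VAE form of the helper: sample from the posterior per chunk, concatenate, then apply the VAE's `scaling_factor`. One subtlety: because `.sample()` draws fresh noise on each call, chunked encoding consumes the RNG differently from a single full-batch call, so the latents agree in distribution rather than bit-for-bit. A sketch with a toy Gaussian posterior standing in for the real `latent_dist`; `0.18215` is the familiar SD-family scaling factor, while the real code reads `vae.config.scaling_factor`:

```python
# Toy version of the KL path: per-chunk posterior sampling, then a global
# scaling factor applied once after concatenation.
import torch
import torch.nn.functional as F

SCALING_FACTOR = 0.18215  # assumption: SD-family default; real code reads the config


class ToyPosterior:
    """Stand-in for the latent distribution returned by a KL VAE encode()."""

    def __init__(self, mean: torch.Tensor) -> None:
        self.mean = mean

    def sample(self) -> torch.Tensor:
        return self.mean + 0.1 * torch.randn_like(self.mean)


def encode(img: torch.Tensor) -> ToyPosterior:
    return ToyPosterior(F.avg_pool2d(img, 8))  # stand-in for vae.encode(img).latent_dist


def forward_vae_kl(img: torch.Tensor, vae_batch_size: int = 8) -> torch.Tensor:
    latents = [
        encode(img[i : i + vae_batch_size]).sample()
        for i in range(0, img.shape[0], vae_batch_size)
    ]
    return torch.cat(latents, dim=0) * SCALING_FACTOR


latents = forward_vae_kl(torch.randn(20, 3, 512, 512))
print(latents.shape)  # torch.Size([20, 3, 64, 64])
```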
18 changes: 16 additions & 2 deletions diffengine/models/editors/ssd_1b/ssd_1b.py
@@ -59,6 +59,7 @@ class SSD1B(StableDiffusionXL):
         input_perturbation_gamma (float): The gamma of input perturbation.
             The recommended value is 0.1 for Input Perturbation.
             Defaults to 0.0.
+        vae_batch_size (int): The batch size of the VAE. Defaults to 8.
         finetune_text_encoder (bool, optional): Whether to fine-tune text
             encoder. Defaults to False.
         gradient_checkpointing (bool): Whether or not to use gradient
@@ -92,6 +93,7 @@ def __init__(
         noise_generator: dict | None = None,
         timesteps_generator: dict | None = None,
         input_perturbation_gamma: float = 0.0,
+        vae_batch_size: int = 8,
         *,
         finetune_text_encoder: bool = False,
         gradient_checkpointing: bool = False,
@@ -122,6 +124,7 @@ def __init__(
         self.pre_compute_text_embeddings = pre_compute_text_embeddings
         self.input_perturbation_gamma = input_perturbation_gamma
         self.enable_xformers = enable_xformers
+        self.vae_batch_size = vae_batch_size
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(
@@ -264,6 +267,18 @@ def set_xformers(self) -> None:
                 msg,
             )
 
+    def _forward_vae(self, img: torch.Tensor, num_batches: int,
+                     ) -> torch.Tensor:
+        """Forward vae."""
+        latents = [
+            self.vae.encode(
+                img[i : i + self.vae_batch_size],
+            ).latent_dist.sample() for i in range(
+                0, num_batches, self.vae_batch_size)
+        ]
+        latents = torch.cat(latents, dim=0)
+        return latents * self.vae.config.scaling_factor
+
     def forward(
         self,
         inputs: dict,
@@ -293,8 +308,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
18 changes: 16 additions & 2 deletions diffengine/models/editors/stable_diffusion/stable_diffusion.py
@@ -53,6 +53,7 @@ class StableDiffusion(BaseModel):
         input_perturbation_gamma (float): The gamma of input perturbation.
             The recommended value is 0.1 for Input Perturbation.
             Defaults to 0.0.
+        vae_batch_size (int): The batch size of the VAE. Defaults to 8.
         finetune_text_encoder (bool, optional): Whether to fine-tune text
             encoder. Defaults to False.
         gradient_checkpointing (bool): Whether or not to use gradient
@@ -79,6 +80,7 @@ def __init__(
         noise_generator: dict | None = None,
         timesteps_generator: dict | None = None,
         input_perturbation_gamma: float = 0.0,
+        vae_batch_size: int = 8,
         *,
         finetune_text_encoder: bool = False,
         gradient_checkpointing: bool = False,
@@ -121,6 +123,7 @@ def __init__(
         self.gradient_checkpointing = gradient_checkpointing
         self.input_perturbation_gamma = input_perturbation_gamma
         self.enable_xformers = enable_xformers
+        self.vae_batch_size = vae_batch_size
 
         if not isinstance(loss, nn.Module):
             loss = MODELS.build(
@@ -336,6 +339,18 @@ def _preprocess_model_input(self,
             input_noise = noise
         return self.scheduler.add_noise(latents, input_noise, timesteps)
 
+    def _forward_vae(self, img: torch.Tensor, num_batches: int,
+                     ) -> torch.Tensor:
+        """Forward vae."""
+        latents = [
+            self.vae.encode(
+                img[i : i + self.vae_batch_size],
+            ).latent_dist.sample() for i in range(
+                0, num_batches, self.vae_batch_size)
+        ]
+        latents = torch.cat(latents, dim=0)
+        return latents * self.vae.config.scaling_factor
+
     def forward(
         self,
         inputs: dict,
@@ -371,8 +386,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
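Since `vae_batch_size` is a plain constructor argument on every editor touched here, it is settable from a normal diffengine config. A hypothetical snippet; the `type`/`model` keys follow diffengine's usual MMEngine-style configs, and only `vae_batch_size` itself is taken from this diff:

```python
# Hypothetical config: build a StableDiffusion editor that encodes images
# through the VAE four at a time. Keys other than vae_batch_size are the
# usual MMEngine-style fields, shown here as assumptions.
model = dict(
    type="StableDiffusion",
    model="runwayml/stable-diffusion-v1-5",
    vae_batch_size=4,
)
```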
@@ -253,8 +253,7 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
 
         noise = self.noise_generator(latents)
 
@@ -182,11 +182,8 @@ def forward(
         else:
             weight = None
 
-        latents = self.vae.encode(inputs["img"]).latent_dist.sample()
-        latents = latents * self.vae.config.scaling_factor
-
-        masked_latents = self.vae.encode(inputs["masked_image"]).latent_dist.sample()
-        masked_latents = masked_latents * self.vae.config.scaling_factor
+        latents = self._forward_vae(inputs["img"], num_batches)
+        masked_latents = self._forward_vae(inputs["masked_image"], num_batches)
 
         mask = F.interpolate(inputs["mask"],
                              size=(latents.shape[2], latents.shape[3]))
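In the inpaint editor above, the image and the masked image both go through the same chunked helper, and the pixel-space mask is resized to latent resolution. A sketch of how the three pieces are typically assembled into the UNet input; the 4 + 1 + 4 channel layout is the standard SD-inpaint convention, assumed here rather than taken from this diff:

```python
# Assemble inpainting UNet input: latents, downsampled mask, masked latents.
import torch
import torch.nn.functional as F

num_batches = 4
latents = torch.randn(num_batches, 4, 64, 64)         # _forward_vae(inputs["img"], ...)
masked_latents = torch.randn(num_batches, 4, 64, 64)  # _forward_vae(inputs["masked_image"], ...)
mask = torch.rand(num_batches, 1, 512, 512)           # pixel-space mask

mask = F.interpolate(mask, size=(latents.shape[2], latents.shape[3]))
unet_input = torch.cat([latents, mask, masked_latents], dim=1)
print(unet_input.shape)  # torch.Size([4, 9, 64, 64])
```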