diff --git a/diffengine/models/editors/amused/amused.py b/diffengine/models/editors/amused/amused.py index a56addd..0fd6193 100644 --- a/diffengine/models/editors/amused/amused.py +++ b/diffengine/models/editors/amused/amused.py @@ -44,6 +44,7 @@ class AMUSEd(BaseModel): It works when training dreambooth with class images. data_preprocessor (dict, optional): The pre-process config of :class:`SDDataPreprocessor`. + vae_batch_size (int): The batch size of vae. Defaults to 8. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. gradient_checkpointing (bool): Whether or not to use gradient @@ -65,6 +66,7 @@ def __init__( text_encoder_lora_config: dict | None = None, prior_loss_weight: float = 1., data_preprocessor: dict | nn.Module | None = None, + vae_batch_size: int = 8, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, @@ -102,6 +104,7 @@ def __init__( self.prior_loss_weight = prior_loss_weight self.gradient_checkpointing = gradient_checkpointing self.enable_xformers = enable_xformers + self.vae_batch_size = vae_batch_size if not isinstance(loss, nn.Module): loss = MODELS.build( @@ -261,6 +264,17 @@ def test_step( msg = "test_step is not implemented now, please use infer." 
raise NotImplementedError(msg) + def _forward_vae(self, img: torch.Tensor, num_batches: int, + ) -> torch.Tensor: + """Forward vae.""" + latents = [] + for i in range(0, num_batches, self.vae_batch_size): + latents_ = self.vae.encode(img[i : i + self.vae_batch_size]).latents + latents_ = self.vae.quantize(latents_)[2][2].reshape( + img[i : i + self.vae_batch_size].shape[0], -1) + latents.append(latents_) + return torch.cat(latents, dim=0) + def forward( self, inputs: dict, @@ -296,9 +310,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latents - latents = self.vae.quantize(latents)[2][2].reshape( - num_batches, -1) + latents = self._forward_vae(inputs["img"], num_batches) timesteps = torch.rand(num_batches, device=self.device) diff --git a/diffengine/models/editors/distill_sd/distill_sd_xl.py b/diffengine/models/editors/distill_sd/distill_sd_xl.py index 1bd06e1..6685506 100644 --- a/diffengine/models/editors/distill_sd/distill_sd_xl.py +++ b/diffengine/models/editors/distill_sd/distill_sd_xl.py @@ -188,8 +188,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/instruct_pix2pix/instruct_pix2pix_xl.py b/diffengine/models/editors/instruct_pix2pix/instruct_pix2pix_xl.py index fa03bc0..af9f342 100644 --- a/diffengine/models/editors/instruct_pix2pix/instruct_pix2pix_xl.py +++ b/diffengine/models/editors/instruct_pix2pix/instruct_pix2pix_xl.py @@ -196,8 +196,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) @@ -230,7 +229,7 @@ def forward( } # condition - cond_latents = self.vae.encode(inputs["condition_img"]).latent_dist.sample() + 
cond_latents = self._forward_vae(inputs["condition_img"], num_batches) # random zeros cond latents mask = torch.multinomial( torch.Tensor([ diff --git a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py index d4e78c3..b971399 100644 --- a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py +++ b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py @@ -258,8 +258,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) @@ -358,8 +357,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/kandinsky/kandinskyv22_decoder.py b/diffengine/models/editors/kandinsky/kandinskyv22_decoder.py index a82b946..6643f66 100644 --- a/diffengine/models/editors/kandinsky/kandinskyv22_decoder.py +++ b/diffengine/models/editors/kandinsky/kandinskyv22_decoder.py @@ -50,6 +50,7 @@ class KandinskyV22Decoder(BaseModel): input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1 for Input Perturbation. Defaults to 0.0. + vae_batch_size (int): The batch size of vae. Defaults to 8. gradient_checkpointing (bool): Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. Defaults to False. 
@@ -73,6 +74,7 @@ def __init__( noise_generator: dict | None = None, timesteps_generator: dict | None = None, input_perturbation_gamma: float = 0.0, + vae_batch_size: int = 8, *, gradient_checkpointing: bool = False, enable_xformers: bool = False, @@ -93,6 +95,7 @@ def __init__( self.gradient_checkpointing = gradient_checkpointing self.input_perturbation_gamma = input_perturbation_gamma self.enable_xformers = enable_xformers + self.vae_batch_size = vae_batch_size if not isinstance(loss, nn.Module): loss = MODELS.build( @@ -303,6 +306,17 @@ def _preprocess_model_input(self, input_noise = noise return self.scheduler.add_noise(latents, input_noise, timesteps) + def _forward_vae(self, img: torch.Tensor, num_batches: int, + ) -> torch.Tensor: + """Forward vae.""" + latents = [ + self.vae.encode( + img[i : i + self.vae_batch_size], + ).latents for i in range( + 0, num_batches, self.vae_batch_size) + ] + return torch.cat(latents, dim=0) + def forward( self, inputs: dict, @@ -332,7 +346,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latents + latents = self._forward_vae(inputs["img"], num_batches) image_embeds = self.image_encoder(inputs["clip_img"]).image_embeds noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/kandinsky/kandinskyv3.py b/diffengine/models/editors/kandinsky/kandinskyv3.py index 8195aab..9adec32 100644 --- a/diffengine/models/editors/kandinsky/kandinskyv3.py +++ b/diffengine/models/editors/kandinsky/kandinskyv3.py @@ -51,6 +51,7 @@ class KandinskyV3(BaseModel): input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1 for Input Perturbation. Defaults to 0.0. + vae_batch_size (int): The batch size of vae. Defaults to 8. gradient_checkpointing (bool): Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. Defaults to False. 
@@ -75,6 +76,7 @@ def __init__( noise_generator: dict | None = None, timesteps_generator: dict | None = None, input_perturbation_gamma: float = 0.0, + vae_batch_size: int = 8, *, gradient_checkpointing: bool = False, enable_xformers: bool = False, @@ -98,6 +100,7 @@ def __init__( self.tokenizer_max_length = tokenizer_max_length self.input_perturbation_gamma = input_perturbation_gamma self.enable_xformers = enable_xformers + self.vae_batch_size = vae_batch_size if not isinstance(loss, nn.Module): loss = MODELS.build( @@ -312,6 +315,17 @@ def _preprocess_model_input(self, input_noise = noise return self.scheduler.add_noise(latents, input_noise, timesteps) + def _forward_vae(self, img: torch.Tensor, num_batches: int, + ) -> torch.Tensor: + """Forward vae.""" + latents = [ + self.vae.encode( + img[i : i + self.vae_batch_size], + ).latents.contiguous() for i in range( + 0, num_batches, self.vae_batch_size) + ] + return torch.cat(latents, dim=0) + def forward( self, inputs: dict, @@ -349,7 +363,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latents.contiguous() + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/lcm/lcm_xl.py b/diffengine/models/editors/lcm/lcm_xl.py index b8d6ae0..b03fd2b 100644 --- a/diffengine/models/editors/lcm/lcm_xl.py +++ b/diffengine/models/editors/lcm/lcm_xl.py @@ -253,8 +253,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/pixart_alpha/pixart_alpha.py b/diffengine/models/editors/pixart_alpha/pixart_alpha.py index 0363222..8f35808 100644 --- a/diffengine/models/editors/pixart_alpha/pixart_alpha.py +++ b/diffengine/models/editors/pixart_alpha/pixart_alpha.py @@ -56,6 +56,7 @@ class 
PixArtAlpha(BaseModel): input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1 for Input Perturbation. Defaults to 0.0. + vae_batch_size (int): The batch size of vae. Defaults to 8. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. gradient_checkpointing (bool): Whether or not to use gradient @@ -83,6 +84,7 @@ def __init__( noise_generator: dict | None = None, timesteps_generator: dict | None = None, input_perturbation_gamma: float = 0.0, + vae_batch_size: int = 8, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, @@ -126,6 +128,7 @@ def __init__( self.gradient_checkpointing = gradient_checkpointing self.input_perturbation_gamma = input_perturbation_gamma self.enable_xformers = enable_xformers + self.vae_batch_size = vae_batch_size if not isinstance(loss, nn.Module): loss = MODELS.build( @@ -353,6 +356,18 @@ def _preprocess_model_input(self, input_noise = noise return self.scheduler.add_noise(latents, input_noise, timesteps) + def _forward_vae(self, img: torch.Tensor, num_batches: int, + ) -> torch.Tensor: + """Forward vae.""" + latents = [ + self.vae.encode( + img[i : i + self.vae_batch_size], + ).latent_dist.sample() for i in range( + 0, num_batches, self.vae_batch_size) + ] + latents = torch.cat(latents, dim=0) + return latents * self.vae.config.scaling_factor + def forward( self, inputs: dict, @@ -390,8 +405,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/ssd_1b/ssd_1b.py b/diffengine/models/editors/ssd_1b/ssd_1b.py index 7645c1d..e1a1384 100644 --- a/diffengine/models/editors/ssd_1b/ssd_1b.py +++ b/diffengine/models/editors/ssd_1b/ssd_1b.py @@ -59,6 +59,7 @@ class SSD1B(StableDiffusionXL): 
input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1 for Input Perturbation. Defaults to 0.0. + vae_batch_size (int): The batch size of vae. Defaults to 8. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. gradient_checkpointing (bool): Whether or not to use gradient @@ -92,6 +93,7 @@ def __init__( noise_generator: dict | None = None, timesteps_generator: dict | None = None, input_perturbation_gamma: float = 0.0, + vae_batch_size: int = 8, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, @@ -122,6 +124,7 @@ def __init__( self.pre_compute_text_embeddings = pre_compute_text_embeddings self.input_perturbation_gamma = input_perturbation_gamma self.enable_xformers = enable_xformers + self.vae_batch_size = vae_batch_size if not isinstance(loss, nn.Module): loss = MODELS.build( @@ -264,6 +267,18 @@ def set_xformers(self) -> None: msg, ) + def _forward_vae(self, img: torch.Tensor, num_batches: int, + ) -> torch.Tensor: + """Forward vae.""" + latents = [ + self.vae.encode( + img[i : i + self.vae_batch_size], + ).latent_dist.sample() for i in range( + 0, num_batches, self.vae_batch_size) + ] + latents = torch.cat(latents, dim=0) + return latents * self.vae.config.scaling_factor + def forward( self, inputs: dict, @@ -293,8 +308,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/stable_diffusion/stable_diffusion.py b/diffengine/models/editors/stable_diffusion/stable_diffusion.py index 9af460d..ee762f2 100644 --- a/diffengine/models/editors/stable_diffusion/stable_diffusion.py +++ b/diffengine/models/editors/stable_diffusion/stable_diffusion.py @@ -53,6 +53,7 @@ class StableDiffusion(BaseModel): input_perturbation_gamma (float): 
The gamma of input perturbation. The recommended value is 0.1 for Input Perturbation. Defaults to 0.0. + vae_batch_size (int): The batch size of vae. Defaults to 8. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. gradient_checkpointing (bool): Whether or not to use gradient @@ -79,6 +80,7 @@ def __init__( noise_generator: dict | None = None, timesteps_generator: dict | None = None, input_perturbation_gamma: float = 0.0, + vae_batch_size: int = 8, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, @@ -121,6 +123,7 @@ def __init__( self.gradient_checkpointing = gradient_checkpointing self.input_perturbation_gamma = input_perturbation_gamma self.enable_xformers = enable_xformers + self.vae_batch_size = vae_batch_size if not isinstance(loss, nn.Module): loss = MODELS.build( @@ -336,6 +339,18 @@ def _preprocess_model_input(self, input_noise = noise return self.scheduler.add_noise(latents, input_noise, timesteps) + def _forward_vae(self, img: torch.Tensor, num_batches: int, + ) -> torch.Tensor: + """Forward vae.""" + latents = [ + self.vae.encode( + img[i : i + self.vae_batch_size], + ).latent_dist.sample() for i in range( + 0, num_batches, self.vae_batch_size) + ] + latents = torch.cat(latents, dim=0) + return latents * self.vae.config.scaling_factor + def forward( self, inputs: dict, @@ -371,8 +386,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py index 70a84ef..6cf3a0f 100644 --- a/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py +++ 
b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py @@ -253,8 +253,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/stable_diffusion_inpaint/stable_diffusion_inpaint.py b/diffengine/models/editors/stable_diffusion_inpaint/stable_diffusion_inpaint.py index 86139d8..a71033d 100644 --- a/diffengine/models/editors/stable_diffusion_inpaint/stable_diffusion_inpaint.py +++ b/diffengine/models/editors/stable_diffusion_inpaint/stable_diffusion_inpaint.py @@ -182,11 +182,8 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor - - masked_latents = self.vae.encode(inputs["masked_image"]).latent_dist.sample() - masked_latents = masked_latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) + masked_latents = self._forward_vae(inputs["masked_image"], num_batches) mask = F.interpolate(inputs["mask"], size=(latents.shape[2], latents.shape[3])) diff --git a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py index cd5ea5c..f4779bd 100644 --- a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py +++ b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py @@ -58,6 +58,7 @@ class StableDiffusionXL(BaseModel): input_perturbation_gamma (float): The gamma of input perturbation. The recommended value is 0.1 for Input Perturbation. Defaults to 0.0. + vae_batch_size (int): The batch size of vae. Defaults to 8. finetune_text_encoder (bool, optional): Whether to fine-tune text encoder. Defaults to False. 
gradient_checkpointing (bool): Whether or not to use gradient @@ -88,6 +89,7 @@ def __init__( # noqa: C901 noise_generator: dict | None = None, timesteps_generator: dict | None = None, input_perturbation_gamma: float = 0.0, + vae_batch_size: int = 8, *, finetune_text_encoder: bool = False, gradient_checkpointing: bool = False, @@ -135,6 +137,7 @@ def __init__( # noqa: C901 self.pre_compute_text_embeddings = pre_compute_text_embeddings self.input_perturbation_gamma = input_perturbation_gamma self.enable_xformers = enable_xformers + self.vae_batch_size = vae_batch_size if not isinstance(loss, nn.Module): loss = MODELS.build( @@ -419,6 +422,18 @@ def _preprocess_model_input(self, input_noise = noise return self.scheduler.add_noise(latents, input_noise, timesteps) + def _forward_vae(self, img: torch.Tensor, num_batches: int, + ) -> torch.Tensor: + """Forward vae.""" + latents = [ + self.vae.encode( + img[i : i + self.vae_batch_size], + ).latent_dist.sample() for i in range( + 0, num_batches, self.vae_batch_size) + ] + latents = torch.cat(latents, dim=0) + return latents * self.vae.config.scaling_factor + def forward( self, inputs: dict, @@ -448,8 +463,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py index 41912fa..ea5a8b0 100644 --- a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py +++ b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py @@ -266,8 +266,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * 
self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents) diff --git a/diffengine/models/editors/stable_diffusion_xl_dpo/stable_diffusion_xl_dpo.py b/diffengine/models/editors/stable_diffusion_xl_dpo/stable_diffusion_xl_dpo.py index 9810c76..d502d75 100644 --- a/diffengine/models/editors/stable_diffusion_xl_dpo/stable_diffusion_xl_dpo.py +++ b/diffengine/models/editors/stable_diffusion_xl_dpo/stable_diffusion_xl_dpo.py @@ -135,8 +135,7 @@ def forward( # num_batches is divided by 2 because we have two images per sample num_batches = len(inputs["img"]) // 2 - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches * 2) noise = self.noise_generator(latents[:num_batches]) # repeat noise for each sample set diff --git a/diffengine/models/editors/stable_diffusion_xl_inpaint/stable_diffusion_xl_inpaint.py b/diffengine/models/editors/stable_diffusion_xl_inpaint/stable_diffusion_xl_inpaint.py index 04b97a3..407a62a 100644 --- a/diffengine/models/editors/stable_diffusion_xl_inpaint/stable_diffusion_xl_inpaint.py +++ b/diffengine/models/editors/stable_diffusion_xl_inpaint/stable_diffusion_xl_inpaint.py @@ -195,11 +195,8 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor - - masked_latents = self.vae.encode(inputs["masked_image"]).latent_dist.sample() - masked_latents = masked_latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) + masked_latents = self._forward_vae(inputs["masked_image"], num_batches) mask = F.interpolate(inputs["mask"], size=(latents.shape[2], latents.shape[3])) diff --git a/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py b/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py index 
4166d8b..895b309 100644 --- a/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py +++ b/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py @@ -235,8 +235,7 @@ def forward( else: weight = None - latents = self.vae.encode(inputs["img"]).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor + latents = self._forward_vae(inputs["img"], num_batches) noise = self.noise_generator(latents)