diff --git a/src/backend/pipelines/lcm_lora.py b/src/backend/pipelines/lcm_lora.py index 9bd6616..093a9cc 100644 --- a/src/backend/pipelines/lcm_lora.py +++ b/src/backend/pipelines/lcm_lora.py @@ -1,5 +1,11 @@ import torch -from diffusers import DiffusionPipeline, LCMScheduler, AutoPipelineForText2Image +from diffusers import ( + DiffusionPipeline, + LCMScheduler, + AutoPipelineForText2Image, + StableDiffusionPipeline, +) +import pathlib def load_lcm_weights( @@ -25,23 +31,44 @@ def get_lcm_lora_pipeline( torch_data_type: torch.dtype, pipeline_args={}, ): - # pipeline = DiffusionPipeline.from_pretrained( - pipeline = AutoPipelineForText2Image.from_pretrained( - base_model_id, - torch_dtype=torch_data_type, - local_files_only=use_local_model, - **pipeline_args, - ) + if pathlib.Path(base_model_id).suffix == ".safetensors": + # When loading a .safetensors model, the pipeline has to be created + # with StableDiffusionPipeline() since it's the only class that + # defines the method from_single_file(); afterwards a new pipeline + # is created using AutoPipelineForText2Image() for ControlNet + # support, in case ControlNet is enabled + dummy_pipeline = StableDiffusionPipeline.from_single_file( + base_model_id, + torch_dtype=torch_data_type, + safety_checker=None, + load_safety_checker=False, + local_files_only=use_local_model, + use_safetensors=True, + ) + pipeline = AutoPipelineForText2Image.from_pipe( + dummy_pipeline, + **pipeline_args, + ) + del dummy_pipeline + else: + pipeline = AutoPipelineForText2Image.from_pretrained( + base_model_id, + torch_dtype=torch_data_type, + local_files_only=use_local_model, + **pipeline_args, + ) load_lcm_weights( pipeline, use_local_model, lcm_lora_id, ) + # Always fuse LCM-LoRA + pipeline.fuse_lora() if "lcm" in lcm_lora_id.lower(): print("LCM LoRA model detected so using recommended LCMScheduler") pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config) - pipeline.unet.to(memory_format=torch.channels_last) + # pipeline.unet.to(memory_format=torch.channels_last) return pipeline