[LoRA SD3] add support for lora fusion in sd3 (huggingface#8616)
* add support for lora fusion in sd3

* add test to ensure fused lora and effective lora produce same outputs
sayakpaul committed Jun 20, 2024
1 parent 25d7bb3 commit 668e34c
Showing 3 changed files with 200 additions and 0 deletions.
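
For orientation, the workflow this commit enables at the pipeline level looks roughly like the following. This is a minimal sketch, not part of the diff; it reuses the checkpoint and pixel-art LoRA from the new `fuse_lora` docstring, and the prompt and step count are only illustrative.

```py
import torch
from diffusers import DiffusionPipeline

# Load an SD3 pipeline and a LoRA adapter (same repos as in the docstring below).
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "nerijs/pixel-art-medium-128-v0.1",
    weight_name="pixel-art-medium-128-v0.1.safetensors",
    adapter_name="pixel",
)

# Merge the LoRA weights into the transformer so inference runs without the extra LoRA layers.
pipeline.fuse_lora(lora_scale=0.7)
image = pipeline("pixel art of a corgi astronaut", num_inference_steps=28).images[0]

# Revert to the original (unfused) transformer weights.
pipeline.unfuse_lora()
```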
75 changes: 75 additions & 0 deletions src/diffusers/loaders/lora.py
@@ -1728,3 +1728,78 @@ def _optionally_disable_offloading(cls, _pipeline):
                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        return (is_model_cpu_offload, is_sequential_cpu_offload)

    def fuse_lora(
        self,
        fuse_transformer: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            fuse_transformer (`bool`, defaults to `True`): Whether to fuse the transformer LoRA parameters.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check the fused weights for NaN values before fusing and to skip fusing them if NaNs are found.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "nerijs/pixel-art-medium-128-v0.1",
            weight_name="pixel-art-medium-128-v0.1.safetensors",
            adapter_name="pixel",
        )
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        if fuse_transformer:
            self.num_fused_loras += 1
            transformer = (
                getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
            )
            transformer.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)

    def unfuse_lora(self, unfuse_transformer: bool = True):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
        if unfuse_transformer:
            for module in transformer.modules():
                if isinstance(module, BaseTunerLayer):
                    module.unmerge()

        self.num_fused_loras -= 1
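
The `safe_fusing` and `adapter_names` arguments documented above compose as in the sketch below. This assumes a pipeline that already has two adapters loaded under the names `"pixel"` and `"detail"`; the names are illustrative and not part of this diff.

```py
# Fuse only the "pixel" adapter, checking the merged weights for NaNs first.
pipeline.fuse_lora(
    lora_scale=1.0,
    safe_fusing=True,         # abort the merge if the fused weights would contain NaNs
    adapter_names=["pixel"],  # the "detail" adapter is left unfused
)

# Reverse the merge on the transformer when the fused weights are no longer needed.
pipeline.unfuse_lora()
```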
43 changes: 43 additions & 0 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -13,6 +13,8 @@
# limitations under the License.


import inspect
from functools import partial
from typing import Any, Dict, List, Optional, Union

import torch
@@ -239,6 +241,47 @@ def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `fuse_lora()`.")

        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

    def _fuse_lora_apply(self, module, adapter_names=None):
        from peft.tuners.tuners_utils import BaseTunerLayer

        merge_kwargs = {"safe_merge": self._safe_fusing}

        if isinstance(module, BaseTunerLayer):
            if self.lora_scale != 1.0:
                module.scale_layer(self.lora_scale)

            # For BC with previous PEFT versions, we need to check the signature
            # of the `merge` method to see if it supports the `adapter_names` argument.
            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
            if "adapter_names" in supported_merge_kwargs:
                merge_kwargs["adapter_names"] = adapter_names
            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                raise ValueError(
                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
                    " to the latest version of PEFT. `pip install -U peft`"
                )

            module.merge(**merge_kwargs)

    def unfuse_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
        self.apply(self._unfuse_lora_apply)

    def _unfuse_lora_apply(self, module):
        from peft.tuners.tuners_utils import BaseTunerLayer

        if isinstance(module, BaseTunerLayer):
            module.unmerge()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
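
Because `fuse_lora`/`unfuse_lora` now live on `SD3Transformer2DModel` itself, the pipeline-level helpers above simply delegate to them. The model-level methods can also be called directly, as in this sketch (it assumes a LoRA adapter has already been added to the transformer):

```py
transformer = pipeline.transformer  # an SD3Transformer2DModel with a LoRA adapter attached

# Merge the active adapter(s) into the base weights; arguments mirror the pipeline-level API.
transformer.fuse_lora(lora_scale=0.7, safe_fusing=False, adapter_names=None)

# ... run inference with the merged weights ...

# Unmerge every PEFT BaseTunerLayer found in the module tree.
transformer.unfuse_lora()
```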
82 changes: 82 additions & 0 deletions tests/lora/test_lora_layers_sd3.py
@@ -205,3 +205,85 @@ def test_simple_inference_with_transformer_lora_and_scale(self):
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "Lora + 0 scale should lead to same result as no LoRA",
        )

    def test_simple_inference_with_transformer_fused(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

    def test_simple_inference_with_transformer_fused_with_no_fusion(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )
        self.assertTrue(
            np.allclose(output_fused, output_lora, atol=1e-3, rtol=1e-3),
            "Fused lora output should match the output of the unfused but still active LoRA.",
        )

    def test_simple_inference_with_transformer_fuse_unfuse(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

        pipe.unfuse_lora()
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        output_unfused_lora = pipe(**inputs).images
        self.assertTrue(
            np.allclose(output_fused, output_unfused_lora, atol=1e-3, rtol=1e-3),
            "Unfusing should restore the same output as the fused LoRA.",
        )
