Kwangjae Park authored and committed Jul 8, 2024
1 parent 963ee05 commit 5a1b1c2
Showing 1,330 changed files with 727,075 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -15,6 +15,7 @@
is_torch_available,
is_torchsde_available,
is_transformers_available,
is_coremltools_available,
)


1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -19,6 +19,7 @@
_LazyModule,
is_flax_available,
is_torch_available,
is_coremltools_available,
)


9 changes: 9 additions & 0 deletions src/diffusers/models/model_loading_utils.py
@@ -24,15 +24,18 @@
import safetensors
import torch
from huggingface_hub.utils import EntryNotFoundError
try:
    import coremltools as ct  # optional dependency; guarded so the module still imports without it
except ImportError:
    ct = None

from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
COREML_COMPILED_FILE_EXTENSION,
WEIGHTS_INDEX_NAME,
_add_variant,
_get_model_file,
is_accelerate_available,
is_torch_version,
is_coremltools_available,
logging,
)

@@ -103,6 +106,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == COREML_COMPILED_FILE_EXTENSION:
return ct.models.CompiledMLModel(checkpoint_file, ct.ComputeUnit.CPU_AND_GPU)
else:
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
@@ -163,6 +168,10 @@ def load_model_dict_into_meta(


def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
    if is_coremltools_available():
        # For a Core ML checkpoint the "state dict" is the CompiledMLModel itself;
        # stash it on the module rather than copying tensors into parameters.
        model_to_load._state_dict = state_dict
        return []

# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
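A minimal sketch of the extension dispatch that `load_state_dict` gains above (the checkpoint path is hypothetical, and coremltools must be installed):

import os
import coremltools as ct

def load_checkpoint(checkpoint_file):
    # Compiled Core ML checkpoints come back as CompiledMLModel objects,
    # not as PyTorch state dicts; other extensions fall through to torch/safetensors.
    file_extension = os.path.basename(checkpoint_file).split(".")[-1]
    if file_extension == "mlmodelc":
        return ct.models.CompiledMLModel(checkpoint_file, ct.ComputeUnit.CPU_AND_GPU)
    raise ValueError(f"not a Core ML checkpoint: {checkpoint_file}")

unet_model = load_checkpoint("unet/model.mlmodelc")  # hypothetical path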
84 changes: 54 additions & 30 deletions src/diffusers/models/modeling_utils.py
@@ -38,12 +38,15 @@
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
COREML_COMPILED_NAME,
COREML_COMPILED_FILE_EXTENSION,
_add_variant,
_get_checkpoint_shard_files,
_get_model_file,
deprecate,
is_accelerate_available,
is_torch_version,
is_coremltools_available,
logging,
)
from ..utils.hub_utils import (
@@ -128,6 +131,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
_no_split_modules = None
if is_coremltools_available():
_state_dict = None
_coreml_type = None

def __init__(self):
super().__init__()
@@ -150,6 +156,12 @@ def __getattr__(self, name: str) -> Any:
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
return super().__getattr__(name)

    def state_dict(self, *args, **kwargs):
        # A loaded Core ML model replaces the usual parameter mapping; fall back
        # to the regular torch state dict when no compiled model is attached.
        if is_coremltools_available() and getattr(self, "_state_dict", None) is not None:
            return self._state_dict
        return super().state_dict(*args, **kwargs)

@property
def is_gradient_checkpointing(self) -> bool:
"""
@@ -536,7 +548,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

allow_pickle = False
if use_safetensors is None:
-    use_safetensors = True
+    if is_coremltools_available():
+        use_safetensors = False
+    else:
+        use_safetensors = True
allow_pickle = True

if low_cpu_mem_usage and not is_accelerate_available():
@@ -721,7 +736,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if model_file is None and not is_sharded:
model_file = _get_model_file(
pretrained_model_name_or_path,
-    weights_name=_add_variant(WEIGHTS_NAME, variant),
+    weights_name=_add_variant(COREML_COMPILED_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
@@ -743,34 +758,36 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if device_map is None and not is_sharded:
param_device = "cpu"
state_dict = load_state_dict(model_file, variant=variant)
-model._convert_deprecated_attention_blocks(state_dict)
-# move the params from meta device to cpu
-missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
-if len(missing_keys) > 0:
-    raise ValueError(
-        f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
-        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
-        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
-        " those weights or else make sure your checkpoint file is correct."
-    )
-
-unexpected_keys = load_model_dict_into_meta(
-    model,
-    state_dict,
-    device=param_device,
-    dtype=torch_dtype,
-    model_name_or_path=pretrained_model_name_or_path,
-)
-
-if cls._keys_to_ignore_on_load_unexpected is not None:
-    for pat in cls._keys_to_ignore_on_load_unexpected:
-        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-if len(unexpected_keys) > 0:
-    logger.warning(
-        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-    )
+if not model_file.endswith(COREML_COMPILED_FILE_EXTENSION):
+    model._convert_deprecated_attention_blocks(state_dict)
+    # move the params from meta device to cpu
+    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+    if len(missing_keys) > 0:
+        raise ValueError(
+            f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
+            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+            " those weights or else make sure your checkpoint file is correct."
+        )
+
+    unexpected_keys = load_model_dict_into_meta(
+        model,
+        state_dict,
+        device=param_device,
+        dtype=torch_dtype,
+        model_name_or_path=pretrained_model_name_or_path,
+    )
+
+    if cls._keys_to_ignore_on_load_unexpected is not None:
+        for pat in cls._keys_to_ignore_on_load_unexpected:
+            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+        )

else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
@@ -836,7 +853,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model = cls.from_config(config, **unused_kwargs)

state_dict = load_state_dict(model_file, variant=variant)
-model._convert_deprecated_attention_blocks(state_dict)
+if model_file.endswith(COREML_COMPILED_FILE_EXTENSION):
+    model._coreml_type = "compiled"
+else:
+    model._convert_deprecated_attention_blocks(state_dict)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
@@ -864,6 +884,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()

if output_loading_info:
return model, loading_info

@@ -880,9 +901,12 @@ def _load_pretrained_model(
):
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
-loaded_keys = list(state_dict.keys())
-
-expected_keys = list(model_state_dict.keys())
+if is_coremltools_available():
+    loaded_keys = []
+    expected_keys = []
+else:
+    loaded_keys = list(state_dict.keys())
+    expected_keys = list(model_state_dict.keys())

original_loaded_keys = loaded_keys

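Taken together, the modeling_utils changes route `from_pretrained` to a compiled Core ML bundle whenever coremltools is installed. A hedged usage sketch (the local directory and its contents are hypothetical; the plain loading branch is forced so that `_coreml_type` and `_state_dict` get populated):

from diffusers import UNet2DConditionModel

# Hypothetical directory holding a converted SDXL UNet as unet/model.mlmodelc.
unet = UNet2DConditionModel.from_pretrained(
    "./sdxl-coreml",
    subfolder="unet",
    low_cpu_mem_usage=False,  # take the plain loading branch, not the meta-device one
)
print(unet._coreml_type)  # "compiled" once the .mlmodelc bundle has been loaded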
18 changes: 18 additions & 0 deletions src/diffusers/models/unets/unet_2d_condition.py
@@ -17,6 +17,7 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint
import numpy as np

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
@@ -1093,6 +1094,23 @@ def forward(
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
if hasattr(self, "_coreml_type"):
if self._coreml_type == "compiled":
encoder_hidden_states = np.expand_dims(encoder_hidden_states.permute((0, 2, 1)).numpy(), axis=2)
kwargs = {
"sample": sample.numpy(),
"timestep": np.array([timestep.numpy(), timestep.numpy()]),
"encoder_hidden_states": encoder_hidden_states,
"text_embeds": added_cond_kwargs["text_embeds"].numpy(),
"time_ids": added_cond_kwargs["time_ids"].numpy()
}

sample = torch.from_numpy(self._state_dict.predict(kwargs)["noise_pred"])
if not return_dict:
return (sample,)

return UNet2DConditionOutput(sample=sample)

default_overall_up_factor = 2**self.num_upsamplers

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
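In the compiled branch above, `self._state_dict` holds the `CompiledMLModel` returned by `load_state_dict`, and `predict` consumes NumPy arrays in Core ML's layout: the CLIP hidden states go from (batch, seq_len, hidden) to (batch, hidden, 1, seq_len). A shape-only sketch with hypothetical SDXL dimensions:

import numpy as np
import torch

# Hypothetical SDXL-sized hidden states: batch 2 (classifier-free guidance),
# 77 tokens, 2048 channels from the two concatenated text encoders.
encoder_hidden_states = torch.randn(2, 77, 2048)

# (batch, seq, hidden) -> (batch, hidden, seq) -> (batch, hidden, 1, seq)
coreml_hidden = np.expand_dims(encoder_hidden_states.permute((0, 2, 1)).numpy(), axis=2)
print(coreml_hidden.shape)  # (2, 2048, 1, 77)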
@@ -401,7 +401,10 @@ def encode_prompt(
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
-prompt_embeds = prompt_embeds.hidden_states[-2]
+if len(prompt_embeds.hidden_states) == 1:
+    prompt_embeds = prompt_embeds.hidden_states[0]
+else:
+    prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
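For context on the indexing being branched on: diffusers counts `clip_skip` back from the penultimate hidden state, while a Core-ML-exported text encoder may expose only a single hidden state, hence the new length check. A toy illustration (the 13-entry tuple is hypothetical, matching a 12-layer CLIP encoder plus its embedding output):

# Hypothetical stand-in for the tuple returned with output_hidden_states=True.
hidden_states = tuple(f"layer_{i}" for i in range(13))

print(hidden_states[-2])                # clip_skip unset: penultimate layer
clip_skip = 1
print(hidden_states[-(clip_skip + 2)])  # clip_skip=1: one layer further back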
3 changes: 3 additions & 0 deletions src/diffusers/utils/__init__.py
@@ -34,6 +34,8 @@
USE_PEFT_BACKEND,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
COREML_COMPILED_NAME,
COREML_COMPILED_FILE_EXTENSION,
)
from .deprecation_utils import deprecate
from .doc_utils import replace_example_docstring
@@ -92,6 +94,7 @@
is_unidecode_available,
is_wandb_available,
is_xformers_available,
is_coremltools_available,
requires_backends,
)
from .loading_utils import load_image
2 changes: 2 additions & 0 deletions src/diffusers/utils/constants.py
@@ -39,6 +39,8 @@
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
COREML_COMPILED_NAME = "model.mlmodelc"
COREML_COMPILED_FILE_EXTENSION = "mlmodelc"

# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
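These names feed the existing weights-name plumbing. Assuming `_add_variant` keeps its usual behavior of splicing the variant in before the file extension, a variant checkpoint would resolve as sketched below:

COREML_COMPILED_NAME = "model.mlmodelc"

def add_variant(weights_name, variant=None):
    # Mirrors _add_variant's splice-before-extension behavior (an assumption here).
    if variant is not None:
        parts = weights_name.split(".")
        weights_name = ".".join(parts[:-1] + [variant] + parts[-1:])
    return weights_name

print(add_variant(COREML_COMPILED_NAME))          # model.mlmodelc
print(add_variant(COREML_COMPILED_NAME, "fp16"))  # model.fp16.mlmodelc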
4 changes: 4 additions & 0 deletions src/diffusers/utils/hub_utils.py
@@ -301,6 +301,10 @@ def _get_model_file(
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file
elif os.path.isdir(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a compiled Core ML checkpoint (.mlmodelc bundles are directories, not files)
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
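The new `isdir` branch exists because a compiled Core ML model is not a single file: an .mlmodelc is a directory bundle, so the `isfile` check above it never matches. A quick illustration (paths hypothetical):

import os

root = "./sdxl-coreml/unet"  # hypothetical checkpoint directory
candidate = os.path.join(root, "model.mlmodelc")

if os.path.isfile(candidate):
    print("single-file weights:", candidate)
elif os.path.isdir(candidate):
    print("Core ML bundle (directory):", candidate)
else:
    print("no checkpoint found")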
17 changes: 17 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -334,6 +334,13 @@ def is_timm_available():

_is_google_colab = "google.colab" in sys.modules

_coremltools_available = importlib.util.find_spec("coremltools") is not None
try:
_coremltools_version = importlib_metadata.version("coremltools")
logger.debug(f"Successfully imported coremltools version {_coremltools_version}")
except importlib_metadata.PackageNotFoundError:
_coremltools_available = False


def is_torch_available():
return _torch_available
@@ -451,6 +458,10 @@ def is_google_colab():
return _is_google_colab


def is_coremltools_available():
return _coremltools_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -573,6 +584,11 @@ def is_google_colab():
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
"""

# docstyle-ignore
COREMLTOOLS_IMPORT_ERROR = """
{0} requires the coremltools library but it was not found in your environment. You can install it with pip: `pip install coremltools`
"""

BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -596,6 +612,7 @@ def is_google_colab():
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
("coremltools", (is_coremltools_available, COREMLTOOLS_IMPORT_ERROR)),
]
)

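The helper slots into the same availability pattern as the other optional backends, so callers can either branch on it or fail with the uniform error message registered in BACKENDS_MAPPING. A hedged sketch of both uses (the loader function and path are hypothetical):

from diffusers.utils import is_coremltools_available
from diffusers.utils.import_utils import requires_backends

def load_compiled_model(path):
    # Raises with COREMLTOOLS_IMPORT_ERROR's message when the backend is missing.
    requires_backends(load_compiled_model, ["coremltools"])
    import coremltools as ct
    return ct.models.CompiledMLModel(path, ct.ComputeUnit.CPU_AND_GPU)

if is_coremltools_available():
    model = load_compiled_model("./unet/model.mlmodelc")  # hypothetical path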
(Diff truncated; the remaining changed files are not shown.)
