Skip to content

Commit

Permalink
Fixed bug in txt2vid not allowing you to use custom models.
Browse files Browse the repository at this point in the history
Fixed the script that convert ckpt models to diffusers not working with some models if they had invalid or broken keys, these keys are now ignored so we can still convert the model.
  • Loading branch information
ZeroCool940711 committed Nov 24, 2022
1 parent 54e7da7 commit 0d73421
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 18 deletions.
11 changes: 7 additions & 4 deletions scripts/convert_original_stable_diffusion_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
#DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LDMTextToImagePipeline,
Expand Down Expand Up @@ -627,8 +627,11 @@ def convert_ldm_clip_checkpoint(checkpoint):
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

text_model.load_state_dict(text_model_dict)

try:
text_model.load_state_dict(text_model_dict)
except RuntimeError:
pass

return text_model

Expand Down Expand Up @@ -748,4 +751,4 @@ def convert_ldm_clip_checkpoint(checkpoint):
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

pipe.save_pretrained(args.dump_path)
pipe.save_pretrained(args.dump_path)
55 changes: 41 additions & 14 deletions scripts/txt2vid.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
#from stable_diffusion_videos import StableDiffusionWalkPipeline

from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
PNDMScheduler
PNDMScheduler, DDPMScheduler, FlaxPNDMScheduler, FlaxDDIMScheduler, \
FlaxDDPMScheduler, FlaxKarrasVeScheduler, IPNDMScheduler, KarrasVeScheduler

from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
Expand Down Expand Up @@ -1125,28 +1126,30 @@ def load_diffusers_model(weights_path,torch_device):

if weights_path == "runwayml/stable-diffusion-v1-5":
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
else:
model_path = weights_path

if not os.path.exists(model_path + "/model_index.json"):
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=st.session_state["defaults"].general.huggingface_token,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
use_local_file=True,
use_auth_token=st.session_state["defaults"].general.huggingface_token,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",

)

StableDiffusionPipeline.save_pretrained(server_state["pipe"], model_path)
else:
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
model_path,
use_local_file=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
use_local_file=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
)

server_state["pipe"].unet.to(torch_device)
Expand Down Expand Up @@ -1358,8 +1361,30 @@ def txt2vid(
klms_scheduler = LMSDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)

#flaxddims_scheduler = FlaxDDIMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)

#flaxddpms_scheduler = FlaxDDPMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)

#flaxpndms_scheduler = FlaxPNDMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)

ddpms_scheduler = DDPMScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)

SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
klms=klms_scheduler,
ddpms=ddpms_scheduler,
#flaxddims=flaxddims_scheduler,
#flaxddpms=flaxddpms_scheduler,
#flaxpndms=flaxpndms_scheduler,
)

with st.session_state["progress_bar_text"].container():
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
Expand Down Expand Up @@ -1721,7 +1746,9 @@ def layout():
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
scheduler_name_list = ["klms", "ddim"]
scheduler_name_list = ["klms", "ddim", "ddpms",
#"flaxddims", "flaxddpms", "flaxpndms"
]
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")

Expand Down

0 comments on commit 0d73421

Please sign in to comment.