From b3f7df82399b27a2128cdf423e933ccf6680df74 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Sat, 22 Jun 2024 15:54:27 +0000 Subject: [PATCH] [fix] better support local ckpt --- opensora/models/stdit/stdit3.py | 5 +++-- opensora/models/vae/vae.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/opensora/models/stdit/stdit3.py b/opensora/models/stdit/stdit3.py index 8703b2d1..bd9672db 100644 --- a/opensora/models/stdit/stdit3.py +++ b/opensora/models/stdit/stdit3.py @@ -448,7 +448,7 @@ def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w): @MODELS.register_module("STDiT3-XL/2") def STDiT3_XL_2(from_pretrained=None, **kwargs): force_huggingface = kwargs.pop("force_huggingface", False) - if force_huggingface or from_pretrained is not None and not os.path.isdir(from_pretrained): + if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained): model = STDiT3.from_pretrained(from_pretrained, **kwargs) else: config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) @@ -460,7 +460,8 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs): @MODELS.register_module("STDiT3-3B/2") def STDiT3_3B_2(from_pretrained=None, **kwargs): - if from_pretrained is not None and not os.path.isdir(from_pretrained): + force_huggingface = kwargs.pop("force_huggingface", False) + if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained): model = STDiT3.from_pretrained(from_pretrained, **kwargs) else: config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs) diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index bf50ec83..9802b02d 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -277,7 +277,7 @@ def OpenSoraVAE_V1_2( scale=scale, ) - if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)): + if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)): model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs) else: config = VideoAutoencoderPipelineConfig(**kwargs)