Skip to content

Commit

Permalink
Merge pull request #531 from hpcaitech/hotfix/hf-load
Browse files Browse the repository at this point in the history
[fix] better support local ckpt
  • Loading branch information
zhengzangw authored Jun 22, 2024
2 parents 8ccd152 + b3f7df8 commit a6036e4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions opensora/models/stdit/stdit3.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
@MODELS.register_module("STDiT3-XL/2")
def STDiT3_XL_2(from_pretrained=None, **kwargs):
force_huggingface = kwargs.pop("force_huggingface", False)
if force_huggingface or from_pretrained is not None and not os.path.isdir(from_pretrained):
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
Expand All @@ -460,7 +460,8 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):

@MODELS.register_module("STDiT3-3B/2")
def STDiT3_3B_2(from_pretrained=None, **kwargs):
if from_pretrained is not None and not os.path.isdir(from_pretrained):
force_huggingface = kwargs.pop("force_huggingface", False)
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion opensora/models/vae/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def OpenSoraVAE_V1_2(
scale=scale,
)

if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)):
if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)):
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
else:
config = VideoAutoencoderPipelineConfig(**kwargs)
Expand Down

0 comments on commit a6036e4

Please sign in to comment.